Update regression test to parse eval results from JSON (#1310)

* export metrics results to json

* fix mmedit

* update docs

* fix test failure

* fix

* fix mmocr metrics

* remove srgan config with no set5 test
pull/1420/head
RunningLeon 2022-11-22 20:47:22 +08:00 committed by GitHub
parent b23411d907
commit de96f51231
20 changed files with 223 additions and 461 deletions

View File

@ -50,6 +50,7 @@ ${MODEL_CFG} \
- `--speed-test`: Whether to activate the speed test.
- `--warmup`: Warm up before counting inference elapsed time; requires `--speed-test` to be set first.
- `--log-interval`: The interval between each log; requires `--speed-test` to be set first.
- `--json-file`: The path of the JSON file used to save evaluation results. Defaults to `./results.json`; see the sketch below for reading it back.
\* Other arguments in `tools/test.py` are used for the speed test. They are not related to evaluation.
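A minimal sketch of reading the saved results back afterwards, assuming the default path above; the key `accuracy_top-1` is only an example, and the actual keys depend on the codebase and the requested metrics:

```python
import mmcv

# Load the metrics written by `tools/test.py ... --json-file ./results.json`.
results = mmcv.load('./results.json')  # mmcv picks the JSON handler from the file extension
# 'accuracy_top-1' is a hypothetical key; use the metric_key from your test yaml.
print(results.get('accuracy_top-1'))
```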

View File

@ -228,7 +228,8 @@ class Classification(BaseTask):
out: Optional[str] = None,
metric_options: Optional[dict] = None,
format_only: bool = False,
log_file: Optional[str] = None) -> None:
log_file: Optional[str] = None,
json_file: Optional[str] = None) -> None:
"""Perform post-processing to predictions of model.
Args:
@ -249,9 +250,11 @@ class Classification(BaseTask):
"""
from mmcv.utils import get_logger
logger = get_logger('test', log_file=log_file, log_level=logging.INFO)
if metrics:
results = dataset.evaluate(outputs, metrics, metric_options)
if json_file is not None:
mmcv.dump(results, json_file, indent=4)
for k, v in results.items():
logger.info(f'{k} : {v:.2f}')
else:

View File

@ -249,7 +249,8 @@ class ObjectDetection(BaseTask):
out: Optional[str] = None,
metric_options: Optional[dict] = None,
format_only: bool = False,
log_file: Optional[str] = None):
log_file: Optional[str] = None,
json_file: Optional[str] = None):
"""Perform post-processing to predictions of model.
Args:
@ -287,7 +288,10 @@ class ObjectDetection(BaseTask):
]:
eval_kwargs.pop(key, None)
eval_kwargs.update(dict(metric=metrics, **kwargs))
logger.info(dataset.evaluate(outputs, **eval_kwargs))
results = dataset.evaluate(outputs, **eval_kwargs)
if json_file is not None:
mmcv.dump(results, json_file, indent=4)
logger.info(results)
def get_preprocess(self) -> Dict:
"""Get the preprocess information for SDK.

View File

@ -178,7 +178,8 @@ class VoxelDetection(BaseTask):
out: Optional[str] = None,
metric_options: Optional[dict] = None,
format_only: bool = False,
log_file: Optional[str] = None):
log_file: Optional[str] = None,
json_file: Optional[str] = None):
if out:
logger = get_root_logger()
logger.info(f'\nwriting results to {out}')
@ -196,7 +197,10 @@ class VoxelDetection(BaseTask):
eval_kwargs.pop(key, None)
eval_kwargs.update(dict(metric=metrics, **kwargs))
dataset.evaluate(outputs, **eval_kwargs)
results = dataset.evaluate(outputs, **eval_kwargs)
if json_file is not None:
mmcv.dump(results, json_file, indent=4)
logger.info(results)
def get_model_name(self) -> str:
"""Get the model name.

View File

@ -257,6 +257,7 @@ class SuperResolution(BaseTask):
metric_options: Optional[dict] = None,
format_only: bool = False,
log_file: Optional[str] = None,
json_file: Optional[str] = None,
**kwargs) -> None:
"""Evaluation function implemented in mmedit.
@ -287,6 +288,8 @@ class SuperResolution(BaseTask):
stats = dataset.evaluate(outputs)
for stat in stats:
logger.info('Eval-{}: {}'.format(stat, stats[stat]))
if json_file is not None:
mmcv.dump(stats, json_file, indent=4)
def get_preprocess(self) -> Dict:
"""Get the preprocess information for SDK.

View File

@ -241,7 +241,8 @@ class TextDetection(BaseTask):
out: Optional[str] = None,
metric_options: Optional[dict] = None,
format_only: bool = False,
log_file: Optional[str] = None):
log_file: Optional[str] = None,
json_file: Optional[str] = None):
"""Perform post-processing to predictions of model.
Args:
@ -279,7 +280,10 @@ class TextDetection(BaseTask):
]:
eval_kwargs.pop(key, None)
eval_kwargs.update(dict(metric=metrics, **kwargs))
logger.info(dataset.evaluate(outputs, **eval_kwargs))
results = dataset.evaluate(outputs, **eval_kwargs)
if json_file is not None:
mmcv.dump(results, json_file, indent=4)
logger.info(results)
def get_preprocess(self) -> Dict:
"""Get the preprocess information for SDK.

View File

@ -255,7 +255,8 @@ class TextRecognition(BaseTask):
out: Optional[str] = None,
metric_options: Optional[dict] = None,
format_only: bool = False,
log_file: Optional[str] = None):
log_file: Optional[str] = None,
json_file: Optional[str] = None):
"""Perform post-processing to predictions of model.
Args:
@ -293,7 +294,10 @@ class TextRecognition(BaseTask):
]:
eval_kwargs.pop(key, None)
eval_kwargs.update(dict(metric=metrics, **kwargs))
logger.info(dataset.evaluate(outputs, **eval_kwargs))
results = dataset.evaluate(outputs, **eval_kwargs)
if json_file is not None:
mmcv.dump(results, json_file, indent=4)
logger.info(results)
def get_preprocess(self) -> Dict:
"""Get the preprocess information for SDK.

View File

@ -272,6 +272,7 @@ class PoseDetection(BaseTask):
metric_options: Optional[dict] = None,
format_only: bool = False,
log_file: Optional[str] = None,
json_file: Optional[str] = None,
**kwargs):
"""Perform post-processing to predictions of model.
@ -307,6 +308,8 @@ class PoseDetection(BaseTask):
eval_config.update(dict(metric=metrics))
results = dataset.evaluate(outputs, res_folder, **eval_config)
if json_file is not None:
mmcv.dump(results, json_file, indent=4)
for k, v in sorted(results.items()):
logger.info(f'{k}: {v:.4f}')

View File

@ -284,7 +284,8 @@ class RotatedDetection(BaseTask):
out: Optional[str] = None,
metric_options: Optional[dict] = None,
format_only: bool = False,
log_file: Optional[str] = None):
log_file: Optional[str] = None,
json_file: Optional[str] = None):
"""Perform post-processing to predictions of model.
Args:
@ -322,7 +323,10 @@ class RotatedDetection(BaseTask):
]:
eval_kwargs.pop(key, None)
eval_kwargs.update(dict(metric=metrics, **kwargs))
logger.info(dataset.evaluate(outputs, **eval_kwargs))
results = dataset.evaluate(outputs, **eval_kwargs)
if json_file is not None:
mmcv.dump(results, json_file, indent=4)
logger.info(results)
def get_preprocess(self) -> Dict:
"""Get the preprocess information for SDK.

View File

@ -227,7 +227,8 @@ class Segmentation(BaseTask):
out: Optional[str] = None,
metric_options: Optional[dict] = None,
format_only: bool = False,
log_file: Optional[str] = None):
log_file: Optional[str] = None,
json_file: Optional[str] = None):
"""Perform post-processing to predictions of model.
Args:
@ -257,7 +258,10 @@ class Segmentation(BaseTask):
if format_only:
dataset.format_results(outputs, **kwargs)
if metrics:
dataset.evaluate(outputs, metrics, logger=logger, **kwargs)
results = dataset.evaluate(
outputs, metrics, logger=logger, **kwargs)
if json_file is not None:
mmcv.dump(results, json_file, indent=4)
def get_preprocess(self) -> Dict:
"""Get the preprocess information for SDK.

View File

@ -14,14 +14,10 @@ globals:
eval_name: accuracy # test.py --metrics args
metric_key: accuracy_top-1 # eval Dict key name
tolerance: 1 # metric ±n%
task_name: Image Classification # metafile.Results.Task
dataset: ImageNet-1k # metafile.Results.Dataset
Top 5 Accuracy:
eval_name: accuracy
metric_key: accuracy_top-5
tolerance: 1 # metric ±n%
task_name: Image Classification
dataset: ImageNet-1k
convert_image: &convert_image
input_img: *img_snake
test_img: *img_color_cat

View File

@ -4,26 +4,21 @@ globals:
images:
input_img: &input_img ../mmdetection/demo/demo.jpg
test_img: &test_img ./tests/data/tiger.jpeg
img_blank: &img_blank
metric_info: &metric_info
box AP: # named after metafile.Results.Metrics
eval_name: bbox # test.py --metrics args
metric_key: bbox_mAP # eval OrderedDict key name
tolerance: 0.2 # metric ±n%
task_name: Object Detection # metafile.Results.Task
dataset: COCO # metafile.Results.Dataset
multi_value: 100
mask AP:
eval_name: segm
metric_key: segm_mAP
tolerance: 1 # metric ±n%
task_name: Instance Segmentation
dataset: COCO
multi_value: 100
PQ:
eval_name: proposal
metric_key: '?'
tolerance: 0.1 # metric ±n%
task_name: Panoptic Segmentation
dataset: COCO
convert_image: &convert_image
input_img: *input_img
test_img: *test_img

View File

@ -10,14 +10,10 @@ globals:
eval_name: bbox # test.py --metrics args
metric_key: bbox_mAP # eval OrderedDict key name
tolerance: 1 # metric ±n%
task_name: 3D Object Detection # metafile.Results.Task
dataset: KITTI # metafile.Results.Dataset
mAP:
eval_name: bbox
metric_key: bbox_mAP
tolerance: 1 # metric ±n%
task_name: 3D Object Detection
dataset: nuScenes
NDS:
eval_name: bbox
metric_key: bbox_mAP

View File

@ -6,16 +6,12 @@ globals:
img_bg: &img_bg ../mmediting/tests/data/gt/baboon.png
metric_info: &metric_info
PSNR: # named after metafile.Results.Metrics
eval_name: PSNR # test.py --metrics args
metric_key: Eval-PSNR # eval log key name
tolerance: 4 # metric ±n%
task_name: Restorers # metafile.Results.Task
dataset: Set5 # metafile.Results.Dataset
metric_key: PSNR # eval log key name
tolerance: 0.2 # metric ±n%
dataset: Set5
SSIM:
eval_name: SSIM
metric_key: Eval-SSIM
metric_key: SSIM
tolerance: 0.02 # metric ±n
task_name: Restorers
dataset: Set5
convert_image: &convert_image
input_img: *img_face
@ -125,7 +121,6 @@ models:
- name: SRGAN
metafile: configs/restorers/srresnet_srgan/metafile.yml
model_configs:
- configs/restorers/srresnet_srgan/srgan_x4c64b16_g1_1000k_div2k.py
- configs/restorers/srresnet_srgan/msrresnet_x4c64b16_g1_1000k_div2k.py
pipelines:
- *pipeline_ts_fp32

View File

@ -10,15 +10,12 @@ globals:
hmean-iou: # named after metafile.Results.Metrics
eval_name: hmean-iou # test.py --metrics args
metric_key: 0_hmean-iou:hmean # eval key name
tolerance: 0.15 # metric ±n%
task_name: Text Detection # metafile.Results.Task
dataset: ICDAR2015 # metafile.Results.Dataset
tolerance: 0.01 # metric ±n%
word_acc:
eval_name: acc
metric_key: 0_word_acc_ignore_case
tolerance: 0.05 # metric ±n%
task_name: Text Recognition
dataset: IIIT5K
tolerance: 1.0 # metric
multi_value: 100
convert_image_det: &convert_image_det
input_img: *img_densetext_det
test_img: *img_demo_text_det

View File

@ -4,20 +4,15 @@ globals:
images:
img_human_pose: &img_human_pose ../mmpose/tests/data/coco/000000000785.jpg
img_human_pose_256x192: &img_human_pose_256x192 ./demo/resources/human-pose.jpg
img_blank: &img_blank
metric_info: &metric_info
AP: # named after metafile.Results.Metrics
eval_name: mAP # test.py --metrics args
metric_key: AP # eval key name
tolerance: 0.10 # metric ±n
task_name: Body 2D Keypoint # metafile.Results.Task
dataset: COCO # metafile.Results.Dataset
AR:
eval_name: mAP
metric_key: AR
tolerance: 0.08 # metric ±n
task_name: Body 2D Keypoint
dataset: COCO
convert_image: &convert_image
input_img: *img_human_pose
test_img: *img_human_pose_256x192

View File

@ -9,8 +9,6 @@ globals:
eval_name: mAP # test.py --metrics args
metric_key: AP # eval key name
tolerance: 0.10 # metric ±n%
task_name: Oriented Object Detection # metafile.Results.Task
dataset: DOTAv1.0 # metafile.Results.Dataset
convert_image_det: &convert_image_det
input_img: *img_demo
test_img: *img_dota_demo

View File

@ -12,9 +12,8 @@ globals:
mIoU: # named after metafile.Results.Metrics
eval_name: mIoU # test.py --metrics args
metric_key: mIoU # eval OrderedDict key name
tolerance: 5 # metric ±n%
task_name: Semantic Segmentation # metafile.Results.Task
dataset: [Cityscapes, ADE20K] # metafile.Results.Dataset
tolerance: 1 # metric ±n%
multi_value: 100
convert_image: &convert_image
input_img: *img_leftImg8bit
test_img: *img_loveda_0

View File

@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import logging
import os
import subprocess
from collections import OrderedDict
from pathlib import Path
from typing import List
@ -13,9 +13,9 @@ import yaml
from torch.hub import download_url_to_file
from torch.multiprocessing import set_start_method
import mmdeploy.version
import mmdeploy
from mmdeploy.utils import (get_backend, get_codebase, get_root_logger,
is_dynamic_shape, load_config)
is_dynamic_shape)
def parse_args():
@ -71,7 +71,7 @@ def merge_report(work_dir: str, logger: logging.Logger):
"""
work_dir = Path(work_dir)
res_file = work_dir.joinpath(
f'mmdeploy_regression_test_{mmdeploy.version.__version__}.xlsx')
f'mmdeploy_regression_test_{mmdeploy.__version__}.xlsx')
logger.info(f'Whole result report saving in {res_file}')
if res_file.exists():
@ -300,83 +300,43 @@ def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
return {}
# get metric
model_info = meta_info.get(model_config_name, None)
metafile_metric_info = model_info.get('Results', None)
metric_list = []
model_info = meta_info[model_config_name]
metafile_metric_info = model_info['Results']
pytorch_metric = dict()
dataset_type = ''
task_type = ''
# Get dataset
using_dataset = dict()
for _, v in test_yaml_metric_info.items():
if v.get('dataset') is None:
continue
dataset_list = v.get('dataset', [])
if not isinstance(dataset_list, list):
dataset_list = [dataset_list]
for metric_dataset in dataset_list:
dataset_tmp = using_dataset.get(metric_dataset, [])
if v.get('task_name') not in dataset_tmp:
dataset_tmp.append(v.get('task_name'))
using_dataset.update({metric_dataset: dataset_tmp})
using_dataset = set()
using_task = set()
configured_dataset = set()
for items in test_yaml_metric_info.values():
if 'dataset' in items:
configured_dataset.add(items['dataset'])
# Get metrics info from metafile
for metafile_metric in metafile_metric_info:
pytorch_meta_metric = metafile_metric.get('Metrics')
dataset = metafile_metric.get('Dataset', '')
task_name = metafile_metric.get('Task', '')
if task_name == 'Restorers':
# mmedit
dataset = 'Set5'
if (len(using_dataset) > 1) and (dataset not in using_dataset):
logger.info(f'dataset not in {using_dataset}, skip it...')
continue
dataset_type += f'{dataset} | '
if task_name not in using_dataset.get(dataset, []):
# only add the metric with the correct dataset
logger.info(f'task_name ({task_name}) is not in'
f'{using_dataset.get(dataset, [])}, skip it...')
continue
task_type += f'{task_name} | '
# remove some metric which not in metric_info from test yaml
for k, v in pytorch_meta_metric.items():
if k not in test_yaml_metric_info and \
'Restorers' not in task_type:
continue
if 'Restorers' in task_type and k not in dataset_type:
# mmedit
continue
if isinstance(v, dict):
# mmedit
for sub_k, sub_v in v.items():
use_metric = {sub_k: sub_v}
metric_list.append(use_metric)
pytorch_metric.update(use_metric)
else:
use_metric = {k: v}
metric_list.append(use_metric)
pytorch_metric.update(use_metric)
dataset_type = dataset_type[:-3].upper() # remove the final ' | '
task_type = task_type[:-3] # remove the final ' | '
# update useless metric
metric_all_list = [str(metric) for metric in test_yaml_metric_info]
metric_useless = set(metric_all_list) - set(
[str(metric) for metric in pytorch_metric])
for metric in metric_useless:
metric_list.append({metric: '-'})
dataset = metafile_metric['Dataset']
_metrics = metafile_metric['Metrics']
if configured_dataset:
for ds in configured_dataset:
if ds in _metrics:
pytorch_metric.update(_metrics[ds])
else:
pytorch_metric.update(_metrics)
task_name = metafile_metric['Task']
if task_name not in using_task:
using_task.add(task_name)
if dataset not in using_dataset:
using_dataset.add(dataset)
dataset_type = '|'.join(list(using_dataset))
task_type = '|'.join(list(using_task))
metric_list = []
for metric in test_yaml_metric_info:
value = '-'
if metric in pytorch_metric:
value = pytorch_metric[metric]
metric_list.append({metric: value})
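# keep only the metrics declared in the regression test yaml for later comparison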
valid_pytorch_metric = {
k: v
for k, v in pytorch_metric.items() if k in test_yaml_metric_info
}
# get pytorch fps value
fps_info = model_info.get('Metadata').get('inference time (ms/im)')
if fps_info is None:
@ -408,116 +368,27 @@ def get_pytorch_result(model_name: str, meta_info: dict, checkpoint_path: Path,
report_txt_path=report_txt_path,
codebase_name=codebase_name)
logger.info(f'Got {model_config_path} metric: {pytorch_metric}')
return pytorch_metric, dataset_type
logger.info(f'Got {model_config_path} metric: {valid_pytorch_metric}')
dataset_info = dict(dataset=dataset_type, task=task_type)
return valid_pytorch_metric, dataset_info
def get_info_from_log_file(info_type: str, log_path: Path,
yaml_metric_key: str, logger: logging.Logger):
"""Get fps and metric result from log file.
def parse_metric_json(json_file: str) -> dict:
"""Parse metrics result from output json file.
Args:
info_type (str): Get which type of info: 'FPS' or 'metric'.
log_path (Path): Logger path.
yaml_metric_key (str): Name of metric from yaml metric_key.
logger (logger.Logger): Logger handler.
json_file (str): Path of the input json file.
Returns:
Float: Info value which get from logger file.
dict: metric results
"""
if log_path.exists():
with open(log_path, 'r') as f_log:
lines = f_log.readlines()
logger = get_root_logger()
if not os.path.exists(json_file):
logger.warning(f'File not found {json_file}')
result = {}
else:
logger.warning(f'{log_path} do not exist !!!')
lines = []
if info_type == 'FPS' and len(lines) > 1:
# Get FPS
line_count = 0
fps_sum = 0.00
fps_lines = lines[1:11]
for line in fps_lines:
if 'FPS' not in line:
continue
line_count += 1
fps_sum += float(line.split(' ')[-2])
if fps_sum > 0.00:
info_value = f'{fps_sum / line_count:.2f}'
else:
info_value = 'x'
elif info_type == 'metric' and len(lines) > 1:
# To calculate the final line index
if lines[-1] != '' and lines[-1] != '\n':
line_index = -1
else:
line_index = -2
if yaml_metric_key in ['accuracy_top-1', 'mIoU', 'Eval-PSNR']:
# info in last second line
# mmcls, mmseg, mmedit
metric_line = lines[line_index - 1]
elif yaml_metric_key == 'AP':
# info in last tenth line
# mmpose
metric_line = lines[line_index - 9]
elif yaml_metric_key == 'AR':
# info in last fifth line
# mmpose
metric_line = lines[line_index - 4]
else:
# info in final line
# mmdet
metric_line = lines[line_index]
logger.info(f'Got metric_line = {metric_line}')
metric_str = \
metric_line.replace('\n', '').replace('\r', '').split(' - ')[-1]
logger.info(f'Got metric_str = {metric_str}')
logger.info(f'Got metric_info = {yaml_metric_key}')
if 'OrderedDict' in metric_str:
# mmdet
evaluate_result = eval(metric_str)
if not isinstance(evaluate_result, OrderedDict):
logger.warning(f'Got error metric_dict = {metric_str}')
return 'x'
metric = evaluate_result.get(yaml_metric_key, 0.00) * 100
elif 'accuracy_top' in metric_str:
# mmcls
metric = eval(metric_str.split(': ')[-1])
if metric <= 1:
metric *= 100
elif yaml_metric_key == 'mIoU' and '|' in metric_str:
# mmseg
metric = eval(metric_str.strip().split('|')[2])
if metric <= 1:
metric *= 100
elif yaml_metric_key in ['AP', 'AR']:
# mmpose
metric = eval(metric_str.split(': ')[-1])
elif yaml_metric_key == '0_word_acc_ignore_case' or \
yaml_metric_key == '0_hmean-iou:hmean':
# mmocr
evaluate_result = eval(metric_str)
if not isinstance(evaluate_result, dict):
logger.warning(f'Got error metric_dict = {metric_str}')
return 'x'
metric = evaluate_result.get(yaml_metric_key, 0.00)
if yaml_metric_key == '0_word_acc_ignore_case':
metric *= 100
elif yaml_metric_key in ['Eval-PSNR', 'Eval-SSIM']:
# mmedit
metric = eval(metric_str.split(': ')[-1])
else:
metric = 'x'
info_value = metric
else:
info_value = 'x'
return info_value
logger.info(f'Parse test result from {json_file}')
result = mmcv.load(json_file)
return result
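# Hedged usage sketch (the path and key below are hypothetical): parse_metric_json
# returns an empty dict when the file is missing, so callers can look up metrics safely:
#     backend_results = parse_metric_json('work_dirs/resnet18/backend_test.json')
#     top1 = backend_results.get('accuracy_top-1', 'x')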
def run_cmd(cmd_lines: List[str], log_path: Path):
@ -567,114 +438,58 @@ def run_cmd(cmd_lines: List[str], log_path: Path):
return return_code
def compare_metric(metric_value: float, metric_name: str, pytorch_metric: dict,
metric_info: dict):
"""Compare metric value with the pytorch metric value and the tolerance.
Args:
metric_value (float): Metric value.
metric_name (str): metric name.
pytorch_metric (dict): Pytorch metric which get from metafile.
metric_info (dict): Metric info from test yaml.
Returns:
Bool: If the test pass or not.
"""
if metric_value == 'x':
return False
metric_pytorch = pytorch_metric.get(str(metric_name))
tolerance_value = metric_info.get(metric_name, {}).get('tolerance', 0.00)
test_pass = metric_value >= (metric_pytorch - tolerance_value)
return test_pass
def get_fps_metric(shell_res: int, pytorch_metric: dict, metric_key: str,
yaml_metric_info_name: str, log_path: Path,
metrics_eval_list: dict, metric_info: dict,
logger: logging.Logger):
def get_fps_metric(shell_res: int, pytorch_metric: dict, metric_info: dict,
json_file: str):
"""Get fps and metric.
Args:
shell_res (int): Backend convert result: 0 is success.
pytorch_metric (dict): Metric info of pytorch metafile.
metric_key (str):Metric info.
yaml_metric_info_name (str): Name of metric info in test yaml.
log_path (Path): Logger path.
metrics_eval_list (dict): Metric list from test yaml.
json_file (str): JSON file of evaluation results.
metric_info (dict): Metric info.
logger (logger.Logger): Logger handler.
Returns:
Float: fps: FPS of the model.
List: metric_list: metric result list.
Bool: test_pass: If the test pass or not.
"""
metric_list = []
fps = '-'
# check whether the conversion succeeded
if shell_res != 0:
fps = 'x'
metric_value = 'x'
backend_results = {}
else:
# Got fps from log file
fps = get_info_from_log_file('FPS', log_path, metric_key, logger)
# logger.info(f'Got fps = {fps}')
backend_results = parse_metric_json(json_file)
# Got metric from log file
metric_value = get_info_from_log_file('metric', log_path, metric_key,
logger)
logger.info(f'Got metric = {metric_value}')
compare_results = {}
output_result = {}
for metric_name, metric_value in pytorch_metric.items():
metric_key = metric_info[metric_name]['metric_key']
tolerance = metric_info[metric_name]['tolerance']
multi_value = metric_info[metric_name].get('multi_value', 1.0)
compare_flag = False
output_result[metric_name] = 'x'
if metric_key in backend_results:
backend_value = backend_results[metric_key] * multi_value
output_result[metric_name] = backend_value
if backend_value >= metric_value - tolerance:
compare_flag = True
compare_results[metric_name] = compare_flag
if yaml_metric_info_name is None:
logger.error(f'metrics_eval_list: {metrics_eval_list} '
'has not metric name')
assert yaml_metric_info_name is not None
metric_list.append({yaml_metric_info_name: metric_value})
test_pass = compare_metric(metric_value, yaml_metric_info_name,
pytorch_metric, metric_info)
# same eval_name and multi metric output in one test
if yaml_metric_info_name == 'Top 1 Accuracy':
# mmcls
yaml_metric_info_name = 'Top 5 Accuracy'
second_get_metric = True
elif yaml_metric_info_name == 'AP':
# mmpose
yaml_metric_info_name = 'AR'
second_get_metric = True
elif yaml_metric_info_name == 'PSNR':
# mmedit
yaml_metric_info_name = 'SSIM'
second_get_metric = True
if len(compare_results):
test_pass = all(list(compare_results.values()))
else:
second_get_metric = False
if second_get_metric:
metric_key = metric_info.get(yaml_metric_info_name).get('metric_key')
if shell_res != 0:
metric_value = 'x'
else:
metric_value = get_info_from_log_file('metric', log_path,
metric_key, logger)
metric_list.append({yaml_metric_info_name: metric_value})
if test_pass:
test_pass = compare_metric(metric_value, yaml_metric_info_name,
pytorch_metric, metric_info)
return fps, metric_list, test_pass
test_pass = False
return fps, output_result, test_pass
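# Hedged worked example of the comparison above, with made-up backend numbers: given a
# yaml entry with metric_key: bbox_mAP, tolerance: 0.2 and multi_value: 100, a backend
# result of {'bbox_mAP': 0.364} becomes 36.4 after scaling by multi_value, and the
# metric passes when 36.4 >= pytorch value of 'box AP' - 0.2.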
def get_backend_fps_metric(deploy_cfg_path: str, model_cfg_path: Path,
convert_checkpoint_path: str, device_type: str,
eval_name: str, logger: logging.Logger,
metrics_eval_list: dict, pytorch_metric: dict,
logger: logging.Logger, pytorch_metric: dict,
metric_info: dict, backend_name: str,
precision_type: str, metric_useless: set,
convert_result: bool, report_dict: dict,
infer_type: str, log_path: Path, dataset_type: str,
report_txt_path: Path, model_name: str):
precision_type: str, convert_result: bool,
report_dict: dict, infer_type: str, log_path: Path,
dataset_info: dict, report_txt_path: Path,
model_name: str):
"""Get backend fps and metric.
Args:
@ -682,62 +497,46 @@ def get_backend_fps_metric(deploy_cfg_path: str, model_cfg_path: Path,
model_cfg_path (Path): Model config path.
convert_checkpoint_path (str): Converted checkpoint path.
device_type (str): Device for converting.
eval_name (str): Name of evaluation.
logger (logging.Logger): Logger handler.
metrics_eval_list (dict): Evaluation metric info dict.
pytorch_metric (dict): Pytorch metric info dict get from metafile.
metric_info (dict): Metric info from test yaml.
backend_name (str): Backend name.
precision_type (str): Precision type for evaluation.
metric_useless (set): Useless metric for specific the model.
convert_result (bool): Backend convert result.
report_dict (dict): Backend convert result.
infer_type (str): Infer type.
log_path (Path): Logger save path.
dataset_type (str): Dataset type.
dataset_info (dict): Dataset info.
report_txt_path (Path): report txt save path.
model_name (str): Name of model in test yaml.
"""
json_file = os.path.splitext(str(log_path))[0] + '.json'
cmd_lines = [
'python3 tools/test.py', f'{deploy_cfg_path}',
f'{str(model_cfg_path.absolute())}',
f'{str(model_cfg_path.absolute())}', f'--json-file {json_file}',
f'--model {convert_checkpoint_path}', f'--device {device_type}'
]
codebase_name = get_codebase(str(deploy_cfg_path)).value
if codebase_name != 'mmedit':
eval_name = ' '.join(
list(set([metric_info[k]['eval_name'] for k in pytorch_metric])))
# mmedit does not use the --metrics flag
cmd_lines += [f'--metrics {eval_name}']
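# Hedged illustration (all paths here are hypothetical) of the assembled command:
#   python3 tools/test.py configs/mmdet/detection/detection_onnxruntime_dynamic.py \
#       /abs/path/to/model_cfg.py --json-file work_dir/backend_test.json \
#       --model end2end.onnx --device cpu --metrics bbox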
# Test backend
return_code = run_cmd(cmd_lines, log_path)
metric_key = ''
metric_name = ''
task_name = ''
for key, value in metric_info.items():
if value.get('eval_name', '') == eval_name:
metric_name = key
metric_key = value.get('metric_key', '')
task_name = value.get('task_name', '')
break
logger.info(f'Got metric_name = {metric_name}')
logger.info(f'Got metric_key = {metric_key}')
fps, metric_list, test_pass = \
get_fps_metric(return_code, pytorch_metric, metric_key, metric_name,
log_path, metrics_eval_list, metric_info, logger)
# update useless metric
for metric in metric_useless:
metric_list.append({metric: '-'})
if len(metrics_eval_list) > 1 and codebase_name == 'mmdet':
# one model has more than one task, like Mask R-CNN
for name in pytorch_metric:
if name in metric_useless or name == metric_name:
continue
metric_list.append({name: '-'})
fps, backend_metric, test_pass = get_fps_metric(return_code,
pytorch_metric,
metric_info, json_file)
logger.info(f'test_pass={test_pass}, results: {backend_metric}')
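# report every metric declared in the test yaml; metrics missing from the backend results stay '-'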
metric_list = []
for metric in metric_info:
value = '-'
if metric in backend_metric:
value = backend_metric[metric]
metric_list.append({metric: value})
dataset_type = dataset_info['dataset']
task_name = dataset_info['task']
# update the report
update_report(
@ -877,23 +676,12 @@ def get_backend_result(pipeline_info: dict, model_cfg_path: Path,
if backend_test is False and sdk_config is None:
test_type = 'convert'
metric_name_list = [str(metric) for metric in pytorch_metric]
assert len(metric_name_list) > 0
metric_all_list = [str(metric) for metric in metric_info]
metric_useless = set(metric_all_list) - set(metric_name_list)
deploy_cfg_path = Path(pipeline_info.get('deploy_config'))
backend_name = str(get_backend(str(deploy_cfg_path)).name).lower()
# change device_type for special case
if backend_name in ['ncnn', 'openvino']:
if backend_name in ['ncnn', 'openvino', 'onnxruntime']:
device_type = 'cpu'
elif backend_name == 'onnxruntime' and device_type != 'cpu':
import onnxruntime as ort
if ort.get_device() != 'GPU':
device_type = 'cpu'
logger.warning('Device type is forced to cpu '
'since onnxruntime-gpu is not installed')
infer_type = \
'dynamic' if is_dynamic_shape(str(deploy_cfg_path)) else 'static'
@ -946,112 +734,65 @@ def get_backend_result(pipeline_info: dict, model_cfg_path: Path,
convert_checkpoint_path = \
str(backend_output_path.joinpath(backend_file_name))
# load deploy_cfg
deploy_cfg, model_cfg = \
load_config(str(deploy_cfg_path),
str(model_cfg_path.absolute()))
# get dataset type
if 'dataset_type' in model_cfg:
dataset_type = \
str(model_cfg.dataset_type).upper().replace('DATASET', '')
else:
dataset_type = metafile_dataset
# Test the model
if convert_result and test_type == 'precision':
# Get evaluation metric from model config
metrics_eval_list = model_cfg.evaluation.get('metric', [])
if isinstance(metrics_eval_list, str):
# some config is using str only
metrics_eval_list = [metrics_eval_list]
# assert len(metrics_eval_list) > 0
logger.info(f'Got metrics_eval_list = {metrics_eval_list}')
if len(metrics_eval_list) == 0 and codebase_name == 'mmedit':
metrics_eval_list = ['PSNR']
# test the model metric
for metric_name in metrics_eval_list:
if backend_test:
log_path = \
gen_log_path(backend_output_path, 'backend_test.log')
get_backend_fps_metric(
deploy_cfg_path=str(deploy_cfg_path),
model_cfg_path=model_cfg_path,
convert_checkpoint_path=convert_checkpoint_path,
device_type=device_type,
eval_name=metric_name,
logger=logger,
metrics_eval_list=metrics_eval_list,
pytorch_metric=pytorch_metric,
metric_info=metric_info,
backend_name=backend_name,
precision_type=precision_type,
metric_useless=metric_useless,
convert_result=convert_result,
report_dict=report_dict,
infer_type=infer_type,
log_path=log_path,
dataset_type=dataset_type,
report_txt_path=report_txt_path,
model_name=model_name)
if backend_test:
log_path = \
gen_log_path(backend_output_path, 'backend_test.log')
get_backend_fps_metric(
deploy_cfg_path=str(deploy_cfg_path),
model_cfg_path=model_cfg_path,
convert_checkpoint_path=convert_checkpoint_path,
device_type=device_type,
logger=logger,
pytorch_metric=pytorch_metric,
metric_info=metric_info,
backend_name=backend_name,
precision_type=precision_type,
convert_result=convert_result,
report_dict=report_dict,
infer_type=infer_type,
log_path=log_path,
dataset_info=metafile_dataset,
report_txt_path=report_txt_path,
model_name=model_name)
if sdk_config is not None:
if sdk_config is not None:
if codebase_name == 'mmcls':
replace_top_in_pipeline_json(backend_output_path, logger)
if codebase_name == 'mmcls':
replace_top_in_pipeline_json(backend_output_path, logger)
log_path = gen_log_path(backend_output_path, 'sdk_test.log')
if backend_name == 'onnxruntime':
# sdk only support onnxruntime of cpu
device_type = 'cpu'
# sdk test
get_backend_fps_metric(
deploy_cfg_path=str(sdk_config),
model_cfg_path=model_cfg_path,
convert_checkpoint_path=str(backend_output_path),
device_type=device_type,
eval_name=metric_name,
logger=logger,
metrics_eval_list=metrics_eval_list,
pytorch_metric=pytorch_metric,
metric_info=metric_info,
backend_name=f'SDK-{backend_name}',
precision_type=precision_type,
metric_useless=metric_useless,
convert_result=convert_result,
report_dict=report_dict,
infer_type=infer_type,
log_path=log_path,
dataset_type=dataset_type,
report_txt_path=report_txt_path,
model_name=model_name)
log_path = gen_log_path(backend_output_path, 'sdk_test.log')
# sdk test
get_backend_fps_metric(
deploy_cfg_path=str(sdk_config),
model_cfg_path=model_cfg_path,
convert_checkpoint_path=str(backend_output_path),
device_type=device_type,
logger=logger,
pytorch_metric=pytorch_metric,
metric_info=metric_info,
backend_name=f'SDK-{backend_name}',
precision_type=precision_type,
convert_result=convert_result,
report_dict=report_dict,
infer_type=infer_type,
log_path=log_path,
dataset_info=metafile_dataset,
report_txt_path=report_txt_path,
model_name=model_name)
else:
logger.info('Only test convert, saving to report...')
metric_list = []
metric_list = [{metric: '-'} for metric in metric_info]
fps = '-'
task_name = ''
for metric in metric_name_list:
metric_list.append({metric: '-'})
metric_task_name = metric_info.get(metric, {}).get('task_name', '')
if metric_task_name in task_name:
logger.debug('metric_task_name exist, skip for adding it...')
continue
task_name += f'{metric_task_name} | '
if ' | ' == task_name[-3:]:
task_name = task_name[:-3]
test_pass = True if convert_result else False
# update useless metric
for metric in metric_useless:
metric_list.append({metric: '-'})
test_pass = convert_result
if convert_result:
report_checkpoint = convert_checkpoint_path
else:
report_checkpoint = str(checkpoint_path)
report_checkpoint = 'x'
dataset_type = metafile_dataset['dataset']
task_name = metafile_dataset['task']
# update the report
update_report(
report_dict=report_dict,

View File

@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
from mmcv import DictAction
from mmcv.parallel import MMDataParallel
@ -59,6 +60,11 @@ def parse_args():
type=str,
help='log evaluation results and speed to file',
default=None)
parser.add_argument(
'--json-file',
type=str,
help='log evaluation results to json file',
default='./results.json')
parser.add_argument(
'--speed-test', action='store_true', help='activate speed test')
parser.add_argument(
@ -141,9 +147,19 @@ def main():
else:
outputs = task_processor.single_gpu_test(model, data_loader, args.show,
args.show_dir)
task_processor.evaluate_outputs(model_cfg, outputs, dataset, args.metrics,
args.out, args.metric_options,
args.format_only, args.log2file)
json_dir, _ = os.path.split(args.json_file)
if json_dir:
os.makedirs(json_dir, exist_ok=True)
task_processor.evaluate_outputs(
model_cfg,
outputs,
dataset,
args.metrics,
args.out,
args.metric_options,
args.format_only,
args.log2file,
json_file=args.json_file)
# only effective when the backend requires explicit clean-up (e.g. Ascend)
destroy_model()