[Dev] Update benchmark regression tools (#556)

* Update benchmark tools

* Fix AccessDeniedError

* Support calculating FLOPs and params with benchmark_valid.py

* Fix typo
Ma Zerun 2021-12-14 17:19:32 +08:00 committed by GitHub
parent e57b8cb33b
commit c7c5ab7a04
4 changed files with 106 additions and 50 deletions

File 1 of 4: benchmark_valid.py

@@ -9,6 +9,7 @@ import numpy as np
import torch
from mmcv import Config
from mmcv.parallel import collate, scatter
from modelindex.load_model_index import load
from rich.console import Console
from rich.table import Table
@@ -68,6 +69,12 @@ def parse_args():
'--inference-time',
action='store_true',
help='Test inference time by running each model 10 times.')
parser.add_argument(
'--flops', action='store_true', help='Get Flops and Params of models')
parser.add_argument(
'--flops-str',
action='store_true',
help='Output FLOPs and params counts in a string form.')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
args = parser.parse_args()
@@ -119,15 +126,34 @@ def inference(config_file, checkpoint, classes, args):
result['model'] = config_file.stem
if args.flops:
from mmcv.cnn.utils import get_model_complexity_info
if hasattr(model, 'extract_feat'):
model.forward = model.extract_feat
flops, params = get_model_complexity_info(
model,
input_shape=(3, ) + resolution,
print_per_layer_stat=False,
as_strings=args.flops_str)
result['flops'] = flops if args.flops_str else int(flops)
result['params'] = params if args.flops_str else int(params)
else:
result['flops'] = ''
result['params'] = ''
return result
def show_summary(summary_data):
def show_summary(summary_data, args):
table = Table(title='Validation Benchmark Regression Summary')
table.add_column('Model')
table.add_column('Validation')
table.add_column('Resolution (h, w)')
table.add_column('Inference Time (std) (ms/im)')
if args.inference_time:
table.add_column('Inference Time (std) (ms/im)')
if args.flops:
table.add_column('Flops', justify='right')
table.add_column('Params', justify='right')
for model_name, summary in summary_data.items():
row = [model_name]
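For reference, a minimal sketch of the mmcv complexity helper the new `--flops` branch calls; the torchvision model and the 224x224 input are illustrative stand-ins, not part of the commit:

# A sketch of mmcv's complexity helper, under the assumptions above.
from mmcv.cnn.utils import get_model_complexity_info
from torchvision.models import resnet18

model = resnet18()
flops, params = get_model_complexity_info(
    model,
    input_shape=(3, 224, 224),   # (C, H, W); no batch dimension
    print_per_layer_stat=False,  # skip the per-layer breakdown
    as_strings=True)             # e.g. '1.82 GFLOPs' / '11.69 M'
print(f'FLOPs: {flops}, Params: {params}')

With `as_strings=False` the helper returns plain numbers, which is what the script stores when `--flops-str` is not set.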
@@ -136,10 +162,13 @@ def show_summary(summary_data):
row.append(f'[{color}]{valid}[/{color}]')
if valid == 'PASS':
row.append(str(summary['resolution']))
if 'time_mean' in summary:
if args.inference_time:
time_mean = f"{summary['time_mean']:.2f}"
time_std = f"{summary['time_std']:.2f}"
row.append(f'{time_mean}\t({time_std})'.expandtabs(8))
if args.flops:
row.append(str(summary['flops']))
row.append(str(summary['params']))
table.add_row(*row)
console.print(table)
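rich matches `add_row` values to columns by position, so the row-building code above must branch on exactly the same flags as the column setup. A stripped-down sketch of that pattern (the model name and values are placeholders):

from rich.console import Console
from rich.table import Table

show_flops = True  # stands in for args.flops

table = Table(title='Validation Benchmark Regression Summary')
table.add_column('Model')
if show_flops:
    table.add_column('Flops', justify='right')

row = ['example-model']  # placeholder name
if show_flops:
    row.append('1.82 GFLOPs')  # placeholder value
table.add_row(*row)

Console().print(table)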
@@ -148,11 +177,9 @@ def show_summary(summary_data):
# Sample test to check whether the inference code is correct
def main(args):
model_index_file = MMCLS_ROOT / 'model-index.yml'
model_index = Config.fromfile(model_index_file)
models = OrderedDict()
for file in model_index.Import:
metafile = Config.fromfile(MMCLS_ROOT / file)
models.update({model.Name: model for model in metafile.Models})
model_index = load(str(model_index_file))
model_index.build_models_with_collections()
models = OrderedDict({model.name: model for model in model_index.models})
logger = get_root_logger(
log_file='benchmark_test_image.log', log_level=logging.INFO)
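The benchmark scripts now parse their model lists with the `modelindex` package instead of mmcv's `Config`. A minimal, self-contained sketch of that flow (the file path is a placeholder):

from collections import OrderedDict

from modelindex.load_model_index import load

model_index = load('model-index.yml')  # placeholder path
# Fold collection-level metadata into every model entry.
model_index.build_models_with_collections()
models = OrderedDict({model.name: model for model in model_index.models})

for name, model in models.items():
    print(name, model.config, model.weights)

Note the lowercase attribute names (`name`, `config`, `weights`, `results`); the old `Config`-based parsing exposed the capitalized YAML keys directly.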
@@ -172,17 +199,33 @@ def main(args):
summary_data = {}
for model_name, model_info in models.items():
config = Path(model_info.Config)
config = Path(model_info.config)
assert config.exists(), f'{model_name}: {config} not found.'
logger.info(f'Processing: {model_name}')
http_prefix = 'https://download.openmmlab.com/mmclassification/'
dataset = model_info.Results[0]['Dataset']
dataset = model_info.results[0].dataset
if args.checkpoint_root is not None:
root = Path(args.checkpoint_root)
checkpoint = root / model_info.Weights[len(http_prefix):]
checkpoint = str(checkpoint)
root = args.checkpoint_root
if 's3://' in args.checkpoint_root:
from mmcv.fileio import FileClient
from petrel_client.common.exception import AccessDeniedError
file_client = FileClient.infer_client(uri=root)
checkpoint = file_client.join_path(
root, model_info.weights[len(http_prefix):])
try:
exists = file_client.exists(checkpoint)
except AccessDeniedError:
exists = False
else:
checkpoint = Path(root) / model_info.weights[len(http_prefix):]
exists = checkpoint.exists()
if exists:
checkpoint = str(checkpoint)
else:
print(f'WARNING: {model_name}: {checkpoint} not found.')
checkpoint = None
else:
checkpoint = None
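This block is the `AccessDeniedError` fix from the commit message: on a Petrel (`s3://`) backend, `exists()` raises instead of returning False when the bucket is not readable. A minimal sketch of the pattern with a made-up bucket URI (requires `petrel_client`):

from mmcv.fileio import FileClient
from petrel_client.common.exception import AccessDeniedError

root = 's3://example-bucket/mmclassification'  # placeholder URI
file_client = FileClient.infer_client(uri=root)
checkpoint = file_client.join_path(root, 'resnet/resnet18.pth')  # placeholder

try:
    exists = file_client.exists(checkpoint)
except AccessDeniedError:
    # Treat an unreadable bucket the same as a missing checkpoint.
    exists = False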
@@ -200,7 +243,7 @@ def main(args):
if args.show:
imshow_infos(args.img, result, wait_time=args.wait_time)
show_summary(summary_data)
show_summary(summary_data, args)
if __name__ == '__main__':

File 2 of 4: the benchmark test script

@@ -1,13 +1,13 @@
import argparse
import os
import os.path as osp
import pickle
import re
from collections import OrderedDict
from datetime import datetime
from pathlib import Path
import mmcv
from mmcv import Config
from modelindex.load_model_index import load
from rich.console import Console
from rich.syntax import Syntax
from rich.table import Table
@@ -69,15 +69,29 @@ def parse_args():
def create_test_job_batch(commands, model_info, args, port, script_name):
fname = model_info.Name
fname = model_info.name
config = Path(model_info.Config)
config = Path(model_info.config)
assert config.exists(), f'{fname}: {config} not found.'
http_prefix = 'https://download.openmmlab.com/mmclassification/'
checkpoint_root = Path(args.checkpoint_root)
checkpoint = checkpoint_root / model_info.Weights[len(http_prefix):]
assert checkpoint.exists(), f'{fname}: {checkpoint} not found.'
if 's3://' in args.checkpoint_root:
from mmcv.fileio import FileClient
from petrel_client.common.exception import AccessDeniedError
file_client = FileClient.infer_client(uri=args.checkpoint_root)
checkpoint = file_client.join_path(
args.checkpoint_root, model_info.weights[len(http_prefix):])
try:
exists = file_client.exists(checkpoint)
except AccessDeniedError:
exists = False
else:
checkpoint_root = Path(args.checkpoint_root)
checkpoint = checkpoint_root / model_info.weights[len(http_prefix):]
exists = checkpoint.exists()
if not exists:
print(f'WARNING: {fname}: {checkpoint} not found.')
return None
job_name = f'{args.job_name}_{fname}'
work_dir = Path(args.work_dir) / fname
@@ -127,11 +141,9 @@ def create_test_job_batch(commands, model_info, args, port, script_name):
def test(args):
# parse model-index.yml
model_index_file = MMCLS_ROOT / 'model-index.yml'
model_index = Config.fromfile(model_index_file)
models = OrderedDict()
for file in model_index.Import:
metafile = Config.fromfile(MMCLS_ROOT / file)
models.update({model.Name: model for model in metafile.Models})
model_index = load(str(model_index_file))
model_index.build_models_with_collections()
models = OrderedDict({model.name: model for model in model_index.models})
script_name = osp.join('tools', 'test.py')
port = args.port
@@ -149,19 +161,21 @@ def test(args):
return
models = filter_models
preview_script = ''
for model_info in models.values():
script_path = create_test_job_batch(commands, model_info, args, port,
script_name)
preview_script = script_path or preview_script
port += 1
command_str = '\n'.join(commands)
preview = Table()
preview.add_column(str(script_path))
preview.add_column(str(preview_script))
preview.add_column('Shell command preview')
preview.add_row(
Syntax.from_path(
script_path,
preview_script,
background_color='default',
line_numbers=True,
word_wrap=True),
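The change above fixes a real bug: `script_path` holds only the last loop value and can be None when a checkpoint is missing, so the preview now tracks the last valid script in `preview_script`. A minimal sketch of the rich preview itself ('job.sh' and the command column are placeholders):

from rich.console import Console
from rich.syntax import Syntax
from rich.table import Table

preview_script = 'job.sh'  # placeholder for the last generated script

preview = Table()
preview.add_column(str(preview_script))
preview.add_column('Shell command preview')
preview.add_row(
    Syntax.from_path(
        preview_script,
        background_color='default',
        line_numbers=True,
        word_wrap=True),
    'bash job.sh')  # placeholder command
Console().print(preview)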
@@ -208,7 +222,7 @@ def save_summary(summary_data, models_map, work_dir):
row.extend([''] * 2)
model_info = models_map[model_name]
row.append(model_info.Config)
row.append(model_info.config)
file.write('| ' + ' | '.join(row) + ' |\n')
file.close()
print('Summary file saved at ' + str(summary_path))
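`save_summary` assembles a markdown table line by line; a sketch of the row format with placeholder values:

row = ['example-model', '69.90', '89.43', 'configs/example/example.py']  # placeholders
print('| ' + ' | '.join(row) + ' |')
# -> | example-model | 69.90 | 89.43 | configs/example/example.py |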
@@ -253,12 +267,9 @@ def show_summary(summary_data):
def summary(args):
model_index_file = MMCLS_ROOT / 'model-index.yml'
model_index = Config.fromfile(model_index_file)
models = OrderedDict()
for file in model_index.Import:
metafile = Config.fromfile(MMCLS_ROOT / file)
models.update({model.Name: model for model in metafile.Models})
model_index = load(str(model_index_file))
model_index.build_models_with_collections()
models = OrderedDict({model.name: model for model in model_index.models})
work_dir = Path(args.work_dir)
@@ -283,10 +294,11 @@ def summary(args):
summary_data[model_name] = {}
continue
results = mmcv.load(result_file)
with open(result_file, 'rb') as file:
results = pickle.load(file)
date = datetime.fromtimestamp(result_file.lstat().st_mtime)
expect_metrics = model_info.Results[0].Metrics
expect_metrics = model_info.results[0].metrics
# extract metrics
summary = {'date': date.strftime('%Y-%m-%d')}
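The result files are now read with plain `pickle` instead of `mmcv.load`, dropping mmcv's extension-based dispatch. The equivalent round trip, with placeholder data and path:

import pickle

results = {'accuracy_top-1': 69.90}  # placeholder payload

with open('results.pkl', 'wb') as file:  # placeholder path
    pickle.dump(results, file)

with open('results.pkl', 'rb') as file:
    results = pickle.load(file)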

File 3 of 4: the benchmark train script

@@ -7,7 +7,7 @@ from datetime import datetime
from pathlib import Path
from zipfile import ZipFile
from mmcv import Config
from modelindex.load_model_index import load
from rich.console import Console
from rich.syntax import Syntax
from rich.table import Table
@@ -70,13 +70,13 @@ def parse_args():
def create_train_job_batch(commands, model_info, args, port, script_name):
fname = model_info.Name
fname = model_info.name
assert 'Gpus' in model_info, \
assert 'Gpus' in model_info.data, \
f"Haven't specify gpu numbers for {fname}"
gpus = model_info.Gpus
gpus = model_info.data['Gpus']
config = Path(model_info.Config)
config = Path(model_info.config)
assert config.exists(), f'"{fname}": {config} not found.'
job_name = f'{args.job_name}_{fname}'
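Repo-specific metafile keys such as `Gpus` (and, later in this file, `Months`) are not part of the model-index schema, so `modelindex` only exposes them through the raw `.data` dict rather than as attributes. A minimal sketch (the yml path is a placeholder):

from modelindex.load_model_index import load

models_cfg = load('bench_train.yml')  # placeholder path
models_cfg.build_models_with_collections()

for model in models_cfg.models:
    gpus = model.data['Gpus']                        # required custom field
    months = model.data.get('Months', range(1, 13))  # optional custom field
    print(model.name, gpus, list(months))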
@@ -125,8 +125,9 @@ def create_train_job_batch(commands, model_info, args, port, script_name):
def train(args):
models_cfg = Config.fromfile(Path(__file__).parent / 'bench_train.yml')
models = {model.Name: model for model in models_cfg.Models}
models_cfg = load(str(Path(__file__).parent / 'bench_train.yml'))
models_cfg.build_models_with_collections()
models = {model.name: model for model in models_cfg.models}
script_name = osp.join('tools', 'train.py')
port = args.port
@@ -145,7 +146,7 @@ def train(args):
models = filter_models
for model_info in models.values():
months = model_info.get('Months', range(1, 13))
months = model_info.data.get('Months', range(1, 13))
if datetime.now().month in months:
script_path = create_train_job_batch(commands, model_info, args,
port, script_name)
@@ -210,7 +211,7 @@ def save_summary(summary_data, models_map, work_dir):
row.extend([''] * 2)
model_info = models_map[model_name]
row.append(model_info.Config)
row.append(model_info.config)
row.append(str(summary['log_file'].relative_to(work_dir)))
zip_file.write(summary['log_file'])
file.write('| ' + ' | '.join(row) + ' |\n')
@@ -258,8 +259,8 @@ def show_summary(summary_data):
def summary(args):
models_cfg = Config.fromfile(Path(__file__).parent / 'bench_train.yml')
models = {model.Name: model for model in models_cfg.Models}
models_cfg = load(str(Path(__file__).parent / 'bench_train.yml'))
models = {model.name: model for model in models_cfg.models}
work_dir = Path(args.work_dir)
dir_map = {p.name: p for p in work_dir.iterdir() if p.is_dir()}
@@ -300,7 +301,7 @@ def summary(args):
if len(val_logs) == 0:
continue
expect_metrics = model_info.Results[0].Metrics
expect_metrics = model_info.results[0].metrics
# extract metrics
summary = {'log_file': log_file}

File 4 of 4: setup.cfg (isort settings)

@@ -14,7 +14,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmcls
known_third_party = PIL,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,pytorch_sphinx_theme,requests,rich,sphinx,torch,torchvision,ts
known_third_party = PIL,matplotlib,mmcv,mmdet,modelindex,numpy,onnxruntime,packaging,pytest,pytorch_sphinx_theme,requests,rich,sphinx,torch,torchvision,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
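With `modelindex` declared third-party, isort keeps its import grouped with the other external packages. A sketch of the resulting order (the mmcls import is illustrative):

# Standard library imports come first...
import os.path as osp
from pathlib import Path

# ...then third-party packages, now including modelindex...
from mmcv import Config
from modelindex.load_model_index import load
from rich.table import Table

# ...then first-party mmcls code.
from mmcls.utils import get_root_logger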