[Dev] Update benchmark regression tools (#556)
* Update benchmark tools * Fix AccessDeniedError * Support to calculate flops and params with benchmark_valid.py * Fix typopull/602/head
parent
e57b8cb33b
commit
c7c5ab7a04
|
@ -9,6 +9,7 @@ import numpy as np
|
|||
import torch
|
||||
from mmcv import Config
|
||||
from mmcv.parallel import collate, scatter
|
||||
from modelindex.load_model_index import load
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
|
@ -68,6 +69,12 @@ def parse_args():
|
|||
'--inference-time',
|
||||
action='store_true',
|
||||
help='Test inference time by run 10 times for each model.')
|
||||
parser.add_argument(
|
||||
'--flops', action='store_true', help='Get Flops and Params of models')
|
||||
parser.add_argument(
|
||||
'--flops-str',
|
||||
action='store_true',
|
||||
help='Output FLOPs and params counts in a string form.')
|
||||
parser.add_argument(
|
||||
'--device', default='cuda:0', help='Device used for inference')
|
||||
args = parser.parse_args()
|
||||
|
@ -119,15 +126,34 @@ def inference(config_file, checkpoint, classes, args):
|
|||
|
||||
result['model'] = config_file.stem
|
||||
|
||||
if args.flops:
|
||||
from mmcv.cnn.utils import get_model_complexity_info
|
||||
if hasattr(model, 'extract_feat'):
|
||||
model.forward = model.extract_feat
|
||||
flops, params = get_model_complexity_info(
|
||||
model,
|
||||
input_shape=(3, ) + resolution,
|
||||
print_per_layer_stat=False,
|
||||
as_strings=args.flops_str)
|
||||
result['flops'] = flops if args.flops_str else int(flops)
|
||||
result['params'] = params if args.flops_str else int(params)
|
||||
else:
|
||||
result['flops'] = ''
|
||||
result['params'] = ''
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def show_summary(summary_data):
|
||||
def show_summary(summary_data, args):
|
||||
table = Table(title='Validation Benchmark Regression Summary')
|
||||
table.add_column('Model')
|
||||
table.add_column('Validation')
|
||||
table.add_column('Resolution (h, w)')
|
||||
table.add_column('Inference Time (std) (ms/im)')
|
||||
if args.inference_time:
|
||||
table.add_column('Inference Time (std) (ms/im)')
|
||||
if args.flops:
|
||||
table.add_column('Flops', justify='right')
|
||||
table.add_column('Params', justify='right')
|
||||
|
||||
for model_name, summary in summary_data.items():
|
||||
row = [model_name]
|
||||
|
@ -136,10 +162,13 @@ def show_summary(summary_data):
|
|||
row.append(f'[{color}]{valid}[/{color}]')
|
||||
if valid == 'PASS':
|
||||
row.append(str(summary['resolution']))
|
||||
if 'time_mean' in summary:
|
||||
if args.inference_time:
|
||||
time_mean = f"{summary['time_mean']:.2f}"
|
||||
time_std = f"{summary['time_std']:.2f}"
|
||||
row.append(f'{time_mean}\t({time_std})'.expandtabs(8))
|
||||
if args.flops:
|
||||
row.append(str(summary['flops']))
|
||||
row.append(str(summary['params']))
|
||||
table.add_row(*row)
|
||||
|
||||
console.print(table)
|
||||
|
@ -148,11 +177,9 @@ def show_summary(summary_data):
|
|||
# Sample test whether the inference code is correct
|
||||
def main(args):
|
||||
model_index_file = MMCLS_ROOT / 'model-index.yml'
|
||||
model_index = Config.fromfile(model_index_file)
|
||||
models = OrderedDict()
|
||||
for file in model_index.Import:
|
||||
metafile = Config.fromfile(MMCLS_ROOT / file)
|
||||
models.update({model.Name: model for model in metafile.Models})
|
||||
model_index = load(str(model_index_file))
|
||||
model_index.build_models_with_collections()
|
||||
models = OrderedDict({model.name: model for model in model_index.models})
|
||||
|
||||
logger = get_root_logger(
|
||||
log_file='benchmark_test_image.log', log_level=logging.INFO)
|
||||
|
@ -172,17 +199,33 @@ def main(args):
|
|||
summary_data = {}
|
||||
for model_name, model_info in models.items():
|
||||
|
||||
config = Path(model_info.Config)
|
||||
config = Path(model_info.config)
|
||||
assert config.exists(), f'{model_name}: {config} not found.'
|
||||
|
||||
logger.info(f'Processing: {model_name}')
|
||||
|
||||
http_prefix = 'https://download.openmmlab.com/mmclassification/'
|
||||
dataset = model_info.Results[0]['Dataset']
|
||||
dataset = model_info.results[0].dataset
|
||||
if args.checkpoint_root is not None:
|
||||
root = Path(args.checkpoint_root)
|
||||
checkpoint = root / model_info.Weights[len(http_prefix):]
|
||||
checkpoint = str(checkpoint)
|
||||
root = args.checkpoint_root
|
||||
if 's3://' in args.checkpoint_root:
|
||||
from mmcv.fileio import FileClient
|
||||
from petrel_client.common.exception import AccessDeniedError
|
||||
file_client = FileClient.infer_client(uri=root)
|
||||
checkpoint = file_client.join_path(
|
||||
root, model_info.weights[len(http_prefix):])
|
||||
try:
|
||||
exists = file_client.exists(checkpoint)
|
||||
except AccessDeniedError:
|
||||
exists = False
|
||||
else:
|
||||
checkpoint = Path(root) / model_info.weights[len(http_prefix):]
|
||||
exists = checkpoint.exists()
|
||||
if exists:
|
||||
checkpoint = str(checkpoint)
|
||||
else:
|
||||
print(f'WARNING: {model_name}: {checkpoint} not found.')
|
||||
checkpoint = None
|
||||
else:
|
||||
checkpoint = None
|
||||
|
||||
|
@ -200,7 +243,7 @@ def main(args):
|
|||
if args.show:
|
||||
imshow_infos(args.img, result, wait_time=args.wait_time)
|
||||
|
||||
show_summary(summary_data)
|
||||
show_summary(summary_data, args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
import pickle
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import mmcv
|
||||
from mmcv import Config
|
||||
from modelindex.load_model_index import load
|
||||
from rich.console import Console
|
||||
from rich.syntax import Syntax
|
||||
from rich.table import Table
|
||||
|
@ -69,15 +69,29 @@ def parse_args():
|
|||
|
||||
def create_test_job_batch(commands, model_info, args, port, script_name):
|
||||
|
||||
fname = model_info.Name
|
||||
fname = model_info.name
|
||||
|
||||
config = Path(model_info.Config)
|
||||
config = Path(model_info.config)
|
||||
assert config.exists(), f'{fname}: {config} not found.'
|
||||
|
||||
http_prefix = 'https://download.openmmlab.com/mmclassification/'
|
||||
checkpoint_root = Path(args.checkpoint_root)
|
||||
checkpoint = checkpoint_root / model_info.Weights[len(http_prefix):]
|
||||
assert checkpoint.exists(), f'{fname}: {checkpoint} not found.'
|
||||
if 's3://' in args.checkpoint_root:
|
||||
from mmcv.fileio import FileClient
|
||||
from petrel_client.common.exception import AccessDeniedError
|
||||
file_client = FileClient.infer_client(uri=args.checkpoint_root)
|
||||
checkpoint = file_client.join_path(
|
||||
args.checkpoint_root, model_info.weights[len(http_prefix):])
|
||||
try:
|
||||
exists = file_client.exists(checkpoint)
|
||||
except AccessDeniedError:
|
||||
exists = False
|
||||
else:
|
||||
checkpoint_root = Path(args.checkpoint_root)
|
||||
checkpoint = checkpoint_root / model_info.weights[len(http_prefix):]
|
||||
exists = checkpoint.exists()
|
||||
if not exists:
|
||||
print(f'WARNING: {fname}: {checkpoint} not found.')
|
||||
return None
|
||||
|
||||
job_name = f'{args.job_name}_{fname}'
|
||||
work_dir = Path(args.work_dir) / fname
|
||||
|
@ -127,11 +141,9 @@ def create_test_job_batch(commands, model_info, args, port, script_name):
|
|||
def test(args):
|
||||
# parse model-index.yml
|
||||
model_index_file = MMCLS_ROOT / 'model-index.yml'
|
||||
model_index = Config.fromfile(model_index_file)
|
||||
models = OrderedDict()
|
||||
for file in model_index.Import:
|
||||
metafile = Config.fromfile(MMCLS_ROOT / file)
|
||||
models.update({model.Name: model for model in metafile.Models})
|
||||
model_index = load(str(model_index_file))
|
||||
model_index.build_models_with_collections()
|
||||
models = OrderedDict({model.name: model for model in model_index.models})
|
||||
|
||||
script_name = osp.join('tools', 'test.py')
|
||||
port = args.port
|
||||
|
@ -149,19 +161,21 @@ def test(args):
|
|||
return
|
||||
models = filter_models
|
||||
|
||||
preview_script = ''
|
||||
for model_info in models.values():
|
||||
script_path = create_test_job_batch(commands, model_info, args, port,
|
||||
script_name)
|
||||
preview_script = script_path or preview_script
|
||||
port += 1
|
||||
|
||||
command_str = '\n'.join(commands)
|
||||
|
||||
preview = Table()
|
||||
preview.add_column(str(script_path))
|
||||
preview.add_column(str(preview_script))
|
||||
preview.add_column('Shell command preview')
|
||||
preview.add_row(
|
||||
Syntax.from_path(
|
||||
script_path,
|
||||
preview_script,
|
||||
background_color='default',
|
||||
line_numbers=True,
|
||||
word_wrap=True),
|
||||
|
@ -208,7 +222,7 @@ def save_summary(summary_data, models_map, work_dir):
|
|||
row.extend([''] * 2)
|
||||
|
||||
model_info = models_map[model_name]
|
||||
row.append(model_info.Config)
|
||||
row.append(model_info.config)
|
||||
file.write('| ' + ' | '.join(row) + ' |\n')
|
||||
file.close()
|
||||
print('Summary file saved at ' + str(summary_path))
|
||||
|
@ -253,12 +267,9 @@ def show_summary(summary_data):
|
|||
|
||||
def summary(args):
|
||||
model_index_file = MMCLS_ROOT / 'model-index.yml'
|
||||
model_index = Config.fromfile(model_index_file)
|
||||
models = OrderedDict()
|
||||
|
||||
for file in model_index.Import:
|
||||
metafile = Config.fromfile(MMCLS_ROOT / file)
|
||||
models.update({model.Name: model for model in metafile.Models})
|
||||
model_index = load(str(model_index_file))
|
||||
model_index.build_models_with_collections()
|
||||
models = OrderedDict({model.name: model for model in model_index.models})
|
||||
|
||||
work_dir = Path(args.work_dir)
|
||||
|
||||
|
@ -283,10 +294,11 @@ def summary(args):
|
|||
summary_data[model_name] = {}
|
||||
continue
|
||||
|
||||
results = mmcv.load(result_file)
|
||||
with open(result_file, 'rb') as file:
|
||||
results = pickle.load(file)
|
||||
date = datetime.fromtimestamp(result_file.lstat().st_mtime)
|
||||
|
||||
expect_metrics = model_info.Results[0].Metrics
|
||||
expect_metrics = model_info.results[0].metrics
|
||||
|
||||
# extract metrics
|
||||
summary = {'date': date.strftime('%Y-%m-%d')}
|
||||
|
|
|
@ -7,7 +7,7 @@ from datetime import datetime
|
|||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
from mmcv import Config
|
||||
from modelindex.load_model_index import load
|
||||
from rich.console import Console
|
||||
from rich.syntax import Syntax
|
||||
from rich.table import Table
|
||||
|
@ -70,13 +70,13 @@ def parse_args():
|
|||
|
||||
def create_train_job_batch(commands, model_info, args, port, script_name):
|
||||
|
||||
fname = model_info.Name
|
||||
fname = model_info.name
|
||||
|
||||
assert 'Gpus' in model_info, \
|
||||
assert 'Gpus' in model_info.data, \
|
||||
f"Haven't specify gpu numbers for {fname}"
|
||||
gpus = model_info.Gpus
|
||||
gpus = model_info.data['Gpus']
|
||||
|
||||
config = Path(model_info.Config)
|
||||
config = Path(model_info.config)
|
||||
assert config.exists(), f'"{fname}": {config} not found.'
|
||||
|
||||
job_name = f'{args.job_name}_{fname}'
|
||||
|
@ -125,8 +125,9 @@ def create_train_job_batch(commands, model_info, args, port, script_name):
|
|||
|
||||
|
||||
def train(args):
|
||||
models_cfg = Config.fromfile(Path(__file__).parent / 'bench_train.yml')
|
||||
models = {model.Name: model for model in models_cfg.Models}
|
||||
models_cfg = load(str(Path(__file__).parent / 'bench_train.yml'))
|
||||
models_cfg.build_models_with_collections()
|
||||
models = {model.name: model for model in models_cfg.models}
|
||||
|
||||
script_name = osp.join('tools', 'train.py')
|
||||
port = args.port
|
||||
|
@ -145,7 +146,7 @@ def train(args):
|
|||
models = filter_models
|
||||
|
||||
for model_info in models.values():
|
||||
months = model_info.get('Months', range(1, 13))
|
||||
months = model_info.data.get('Months', range(1, 13))
|
||||
if datetime.now().month in months:
|
||||
script_path = create_train_job_batch(commands, model_info, args,
|
||||
port, script_name)
|
||||
|
@ -210,7 +211,7 @@ def save_summary(summary_data, models_map, work_dir):
|
|||
row.extend([''] * 2)
|
||||
|
||||
model_info = models_map[model_name]
|
||||
row.append(model_info.Config)
|
||||
row.append(model_info.config)
|
||||
row.append(str(summary['log_file'].relative_to(work_dir)))
|
||||
zip_file.write(summary['log_file'])
|
||||
file.write('| ' + ' | '.join(row) + ' |\n')
|
||||
|
@ -258,8 +259,8 @@ def show_summary(summary_data):
|
|||
|
||||
|
||||
def summary(args):
|
||||
models_cfg = Config.fromfile(Path(__file__).parent / 'bench_train.yml')
|
||||
models = {model.Name: model for model in models_cfg.Models}
|
||||
models_cfg = load(str(Path(__file__).parent / 'bench_train.yml'))
|
||||
models = {model.name: model for model in models_cfg.models}
|
||||
|
||||
work_dir = Path(args.work_dir)
|
||||
dir_map = {p.name: p for p in work_dir.iterdir() if p.is_dir()}
|
||||
|
@ -300,7 +301,7 @@ def summary(args):
|
|||
if len(val_logs) == 0:
|
||||
continue
|
||||
|
||||
expect_metrics = model_info.Results[0].Metrics
|
||||
expect_metrics = model_info.results[0].metrics
|
||||
|
||||
# extract metrics
|
||||
summary = {'log_file': log_file}
|
||||
|
|
|
@ -14,7 +14,7 @@ line_length = 79
|
|||
multi_line_output = 0
|
||||
known_standard_library = pkg_resources,setuptools
|
||||
known_first_party = mmcls
|
||||
known_third_party = PIL,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,pytorch_sphinx_theme,requests,rich,sphinx,torch,torchvision,ts
|
||||
known_third_party = PIL,matplotlib,mmcv,mmdet,modelindex,numpy,onnxruntime,packaging,pytest,pytorch_sphinx_theme,requests,rich,sphinx,torch,torchvision,ts
|
||||
no_lines_before = STDLIB,LOCALFOLDER
|
||||
default_section = THIRDPARTY
|
||||
|
||||
|
|
Loading…
Reference in New Issue