[Feature] benchmark.py supports --work-dir and --repeat-times; test.py now has a default work dir (#1126)

* [Feature] benchmark can take a work_dir and repeat times

* change the parameter's name

* change the name of the log file

* add skip load when the checkpoint does not exist

* add default work dir

* make it optional

* Update tools/benchmark.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update tools/benchmark.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* fix typo

* modify json name

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Author: Rockey
Committed: 2022-01-18 16:38:31 +08:00 by GitHub
Parent: b997a13e28
Commit: 7512f05990
2 changed files with 92 additions and 44 deletions
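
Both tools keep their existing CLI; the new options are optional. As a hypothetical example (config and checkpoint paths made up), running "python tools/benchmark.py configs/my_config.py my_ckpt.pth --repeat-times 3" benchmarks three times and dumps the FPS statistics to ./work_dirs/my_config/fps_<timestamp>.json, while passing --work-dir writes the JSON to the given directory instead.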

tools/benchmark.py

@@ -1,7 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import os.path as osp
 import time

+import mmcv
+import numpy as np
 import torch
 from mmcv import Config
 from mmcv.parallel import MMDataParallel
@@ -17,6 +20,11 @@ def parse_args():
     parser.add_argument('checkpoint', help='checkpoint file')
     parser.add_argument(
         '--log-interval', type=int, default=50, help='interval of logging')
+    parser.add_argument(
+        '--work-dir',
+        help=('if specified, the results will be dumped '
+              'into the directory as json'))
+    parser.add_argument('--repeat-times', type=int, default=1)
     args = parser.parse_args()
     return args

@@ -25,61 +33,87 @@ def main():
     args = parse_args()

     cfg = Config.fromfile(args.config)
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    if args.work_dir is not None:
+        mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
+        json_file = osp.join(args.work_dir, f'fps_{timestamp}.json')
+    else:
+        # use config filename as default work_dir if cfg.work_dir is None
+        work_dir = osp.join('./work_dirs',
+                            osp.splitext(osp.basename(args.config))[0])
+        mmcv.mkdir_or_exist(osp.abspath(work_dir))
+        json_file = osp.join(work_dir, f'fps_{timestamp}.json')
+
+    repeat_times = args.repeat_times
     # set cudnn_benchmark
     torch.backends.cudnn.benchmark = False

     cfg.model.pretrained = None
     cfg.data.test.test_mode = True

-    # build the dataloader
-    # TODO: support multiple images per gpu (only minor changes are needed)
-    dataset = build_dataset(cfg.data.test)
-    data_loader = build_dataloader(
-        dataset,
-        samples_per_gpu=1,
-        workers_per_gpu=cfg.data.workers_per_gpu,
-        dist=False,
-        shuffle=False)
-
-    # build the model and load checkpoint
-    cfg.model.train_cfg = None
-    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
-    fp16_cfg = cfg.get('fp16', None)
-    if fp16_cfg is not None:
-        wrap_fp16_model(model)
-    load_checkpoint(model, args.checkpoint, map_location='cpu')
-
-    model = MMDataParallel(model, device_ids=[0])
-
-    model.eval()
-
-    # the first several iterations may be very slow so skip them
-    num_warmup = 5
-    pure_inf_time = 0
-    total_iters = 200
-
-    # benchmark with 200 image and take the average
-    for i, data in enumerate(data_loader):
-
-        torch.cuda.synchronize()
-        start_time = time.perf_counter()
-
-        with torch.no_grad():
-            model(return_loss=False, rescale=True, **data)
-
-        torch.cuda.synchronize()
-        elapsed = time.perf_counter() - start_time
-
-        if i >= num_warmup:
-            pure_inf_time += elapsed
-            if (i + 1) % args.log_interval == 0:
-                fps = (i + 1 - num_warmup) / pure_inf_time
-                print(f'Done image [{i + 1:<3}/ {total_iters}], '
-                      f'fps: {fps:.2f} img / s')
-
-        if (i + 1) == total_iters:
-            fps = (i + 1 - num_warmup) / pure_inf_time
-            print(f'Overall fps: {fps:.2f} img / s')
-            break
+    benchmark_dict = dict(config=args.config, unit='img / s')
+    overall_fps_list = []
+    for time_index in range(repeat_times):
+        print(f'Run {time_index + 1}:')
+        # build the dataloader
+        # TODO: support multiple images per gpu (only minor changes are needed)
+        dataset = build_dataset(cfg.data.test)
+        data_loader = build_dataloader(
+            dataset,
+            samples_per_gpu=1,
+            workers_per_gpu=cfg.data.workers_per_gpu,
+            dist=False,
+            shuffle=False)
+
+        # build the model and load checkpoint
+        cfg.model.train_cfg = None
+        model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
+        fp16_cfg = cfg.get('fp16', None)
+        if fp16_cfg is not None:
+            wrap_fp16_model(model)
+        if 'checkpoint' in args and osp.exists(args.checkpoint):
+            load_checkpoint(model, args.checkpoint, map_location='cpu')
+
+        model = MMDataParallel(model, device_ids=[0])
+
+        model.eval()
+
+        # the first several iterations may be very slow so skip them
+        num_warmup = 5
+        pure_inf_time = 0
+        total_iters = 200
+
+        # benchmark with 200 image and take the average
+        for i, data in enumerate(data_loader):
+
+            torch.cuda.synchronize()
+            start_time = time.perf_counter()
+
+            with torch.no_grad():
+                model(return_loss=False, rescale=True, **data)
+
+            torch.cuda.synchronize()
+            elapsed = time.perf_counter() - start_time
+
+            if i >= num_warmup:
+                pure_inf_time += elapsed
+                if (i + 1) % args.log_interval == 0:
+                    fps = (i + 1 - num_warmup) / pure_inf_time
+                    print(f'Done image [{i + 1:<3}/ {total_iters}], '
+                          f'fps: {fps:.2f} img / s')
+
+            if (i + 1) == total_iters:
+                fps = (i + 1 - num_warmup) / pure_inf_time
+                print(f'Overall fps: {fps:.2f} img / s\n')
+                benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2)
+                overall_fps_list.append(fps)
+                break
+    benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2)
+    benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4)
+    print(f'Average fps of {repeat_times} evaluations: '
+          f'{benchmark_dict["average_fps"]}')
+    print(f'The variance of {repeat_times} evaluations: '
+          f'{benchmark_dict["fps_variance"]}')
+    mmcv.dump(benchmark_dict, json_file, indent=4)


 if __name__ == '__main__':
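
For reference, the dict dumped above has roughly the shape sketched below; this is a minimal illustration where the values are placeholders rather than real measurements, and one overall_fps_* key is written per repeat:

# rough shape of the benchmark result written to fps_<timestamp>.json
benchmark_dict = {
    'config': 'path/to/config.py',  # config path from the command line
    'unit': 'img / s',
    'overall_fps_1': ...,  # FPS of run 1, rounded to 2 decimals
    'overall_fps_2': ...,  # one key per --repeat-times run
    'average_fps': ...,    # np.mean over all runs, 2 decimals
    'fps_variance': ...,   # np.var over all runs, 4 decimals
}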

tools/test.py

@@ -109,7 +109,6 @@ def parse_args():

 def main():
     args = parse_args()
-
     assert args.out or args.eval or args.format_only or args.show \
         or args.show_dir, \
         ('Please specify at least one operation (save/eval/format/show the '
@@ -149,7 +148,23 @@ def main():
     if args.work_dir is not None and rank == 0:
         mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
         timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
-        json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')
+        if args.aug_test:
+            json_file = osp.join(args.work_dir,
+                                 f'eval_multi_scale_{timestamp}.json')
+        else:
+            json_file = osp.join(args.work_dir,
+                                 f'eval_single_scale_{timestamp}.json')
+    elif rank == 0:
+        work_dir = osp.join('./work_dirs',
+                            osp.splitext(osp.basename(args.config))[0])
+        mmcv.mkdir_or_exist(osp.abspath(work_dir))
+        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+        if args.aug_test:
+            json_file = osp.join(work_dir,
+                                 f'eval_multi_scale_{timestamp}.json')
+        else:
+            json_file = osp.join(work_dir,
+                                 f'eval_single_scale_{timestamp}.json')

     # build the dataloader
     # TODO: support multiple images per gpu (only minor changes are needed)
@@ -248,8 +263,7 @@ def main():
             eval_kwargs.update(metric=args.eval)
             metric = dataset.evaluate(results, **eval_kwargs)
             metric_dict = dict(config=args.config, metric=metric)
-            if args.work_dir is not None and rank == 0:
-                mmcv.dump(metric_dict, json_file, indent=4)
+            mmcv.dump(metric_dict, json_file, indent=4)
         if tmpdir is not None and eval_on_format_results:
             # remove tmp dir when cityscapes evaluation
             shutil.rmtree(tmpdir)
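
To summarize the test.py change: when --work-dir is omitted, rank 0 now falls back to a work dir derived from the config file name, and the eval JSON is named after the test scale. A minimal standalone sketch of that naming, assuming a made-up config path:

import os.path as osp
import time

config = 'configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py'  # hypothetical
aug_test = False  # True when testing multi-scale via --aug-test

# default work dir: ./work_dirs/<config file name without extension>
work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(config))[0])
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
scale = 'multi_scale' if aug_test else 'single_scale'
json_file = osp.join(work_dir, f'eval_{scale}_{timestamp}.json')
print(json_file)  # e.g. ./work_dirs/pspnet_r50-d8_512x1024_40k_cityscapes/eval_single_scale_20220118_163831.json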