[Feature] benchmark can add work_dir and repeat times, test.py now has default work-dir (#1126)

* [Feature] benchmark can add work_dir and repeat times

* change the parameter's name

* change the name of the log file

* add skip load (do not load the checkpoint if the file does not exist)

* add default work dir

* make it optional

* Update tools/benchmark.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* Update tools/benchmark.py

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

* fix typo

* modify json name

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Rockey 2022-01-18 16:38:31 +08:00 committed by GitHub
parent 634fbceab0
commit fa0b1ead3e
2 changed files with 92 additions and 44 deletions
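
For reference, a minimal sketch (not part of the commit itself) of the default output location both scripts now fall back to when --work-dir is not given; the config path and file prefix below are illustrative:

import os.path as osp
import time

import mmcv

config_path = 'configs/pspnet/pspnet_r50-d8_512x1024_80k_cityscapes.py'  # illustrative

# default work dir: ./work_dirs/<config file name without extension>
work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(config_path))[0])
mmcv.mkdir_or_exist(osp.abspath(work_dir))

# timestamped result file, e.g. fps_<timestamp>.json for benchmark.py
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
json_file = osp.join(work_dir, f'fps_{timestamp}.json')
mmcv.dump(dict(config=config_path, unit='img / s'), json_file, indent=4)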

tools/benchmark.py

@@ -1,7 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
+import os.path as osp
 import time
 
+import mmcv
+import numpy as np
 import torch
 from mmcv import Config
 from mmcv.parallel import MMDataParallel
@@ -17,6 +20,11 @@ def parse_args():
     parser.add_argument('checkpoint', help='checkpoint file')
     parser.add_argument(
         '--log-interval', type=int, default=50, help='interval of logging')
+    parser.add_argument(
+        '--work-dir',
+        help=('if specified, the results will be dumped '
+              'into the directory as json'))
+    parser.add_argument('--repeat-times', type=int, default=1)
     args = parser.parse_args()
     return args
 
@@ -25,61 +33,87 @@ def main():
     args = parse_args()
 
     cfg = Config.fromfile(args.config)
+    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+    if args.work_dir is not None:
+        mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
+        json_file = osp.join(args.work_dir, f'fps_{timestamp}.json')
+    else:
+        # use config filename as default work_dir if cfg.work_dir is None
+        work_dir = osp.join('./work_dirs',
+                            osp.splitext(osp.basename(args.config))[0])
+        mmcv.mkdir_or_exist(osp.abspath(work_dir))
+        json_file = osp.join(work_dir, f'fps_{timestamp}.json')
+    repeat_times = args.repeat_times
     # set cudnn_benchmark
     torch.backends.cudnn.benchmark = False
     cfg.model.pretrained = None
     cfg.data.test.test_mode = True
 
-    # build the dataloader
-    # TODO: support multiple images per gpu (only minor changes are needed)
-    dataset = build_dataset(cfg.data.test)
-    data_loader = build_dataloader(
-        dataset,
-        samples_per_gpu=1,
-        workers_per_gpu=cfg.data.workers_per_gpu,
-        dist=False,
-        shuffle=False)
+    benchmark_dict = dict(config=args.config, unit='img / s')
+    overall_fps_list = []
+    for time_index in range(repeat_times):
+        print(f'Run {time_index + 1}:')
+        # build the dataloader
+        # TODO: support multiple images per gpu (only minor changes are needed)
+        dataset = build_dataset(cfg.data.test)
+        data_loader = build_dataloader(
+            dataset,
+            samples_per_gpu=1,
+            workers_per_gpu=cfg.data.workers_per_gpu,
+            dist=False,
+            shuffle=False)
 
-    # build the model and load checkpoint
-    cfg.model.train_cfg = None
-    model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
-    fp16_cfg = cfg.get('fp16', None)
-    if fp16_cfg is not None:
-        wrap_fp16_model(model)
-    load_checkpoint(model, args.checkpoint, map_location='cpu')
+        # build the model and load checkpoint
+        cfg.model.train_cfg = None
+        model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
+        fp16_cfg = cfg.get('fp16', None)
+        if fp16_cfg is not None:
+            wrap_fp16_model(model)
+        if 'checkpoint' in args and osp.exists(args.checkpoint):
+            load_checkpoint(model, args.checkpoint, map_location='cpu')
 
-    model = MMDataParallel(model, device_ids=[0])
+        model = MMDataParallel(model, device_ids=[0])
 
-    model.eval()
+        model.eval()
 
-    # the first several iterations may be very slow so skip them
-    num_warmup = 5
-    pure_inf_time = 0
-    total_iters = 200
+        # the first several iterations may be very slow so skip them
+        num_warmup = 5
+        pure_inf_time = 0
+        total_iters = 200
 
-    # benchmark with 200 image and take the average
-    for i, data in enumerate(data_loader):
+        # benchmark with 200 image and take the average
+        for i, data in enumerate(data_loader):
 
-        torch.cuda.synchronize()
-        start_time = time.perf_counter()
+            torch.cuda.synchronize()
+            start_time = time.perf_counter()
 
-        with torch.no_grad():
-            model(return_loss=False, rescale=True, **data)
+            with torch.no_grad():
+                model(return_loss=False, rescale=True, **data)
 
-        torch.cuda.synchronize()
-        elapsed = time.perf_counter() - start_time
+            torch.cuda.synchronize()
+            elapsed = time.perf_counter() - start_time
 
-        if i >= num_warmup:
-            pure_inf_time += elapsed
-            if (i + 1) % args.log_interval == 0:
+            if i >= num_warmup:
+                pure_inf_time += elapsed
+                if (i + 1) % args.log_interval == 0:
+                    fps = (i + 1 - num_warmup) / pure_inf_time
+                    print(f'Done image [{i + 1:<3}/ {total_iters}], '
+                          f'fps: {fps:.2f} img / s')
+
+            if (i + 1) == total_iters:
                 fps = (i + 1 - num_warmup) / pure_inf_time
-                print(f'Done image [{i + 1:<3}/ {total_iters}], '
-                      f'fps: {fps:.2f} img / s')
-
-        if (i + 1) == total_iters:
-            fps = (i + 1 - num_warmup) / pure_inf_time
-            print(f'Overall fps: {fps:.2f} img / s')
-            break
+                print(f'Overall fps: {fps:.2f} img / s\n')
+                benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2)
+                overall_fps_list.append(fps)
+                break
+    benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2)
+    benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4)
+    print(f'Average fps of {repeat_times} evaluations: '
+          f'{benchmark_dict["average_fps"]}')
+    print(f'The variance of {repeat_times} evaluations: '
+          f'{benchmark_dict["fps_variance"]}')
+    mmcv.dump(benchmark_dict, json_file, indent=4)
 
 
 if __name__ == '__main__':
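
For readers unfamiliar with the timing pattern above, here is a stripped-down sketch of what each run does: the GPU is synchronized around every forward pass, the first warm-up iterations are excluded, and the per-run FPS numbers are then summarized with numpy. The helper functions are hypothetical, not part of benchmark.py; model and data_loader stand for the objects the script builds.

import time

import numpy as np
import torch


def benchmark_one_run(model, data_loader, num_warmup=5, total_iters=200):
    # hypothetical helper for illustration only
    # assumes the loader yields at least total_iters batches
    pure_inf_time = 0
    for i, data in enumerate(data_loader):
        torch.cuda.synchronize()  # make sure previously queued GPU work is done
        start_time = time.perf_counter()
        with torch.no_grad():
            model(return_loss=False, rescale=True, **data)
        torch.cuda.synchronize()  # wait for this forward pass to finish
        elapsed = time.perf_counter() - start_time
        if i >= num_warmup:
            pure_inf_time += elapsed
        if (i + 1) == total_iters:
            return (i + 1 - num_warmup) / pure_inf_time


def summarize(overall_fps_list):
    # hypothetical helper; aggregates per-run FPS the same way benchmark.py does
    return dict(average_fps=round(float(np.mean(overall_fps_list)), 2),
                fps_variance=round(float(np.var(overall_fps_list)), 4))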

tools/test.py

@@ -109,7 +109,6 @@ def parse_args():
 
 def main():
     args = parse_args()
-
     assert args.out or args.eval or args.format_only or args.show \
         or args.show_dir, \
         ('Please specify at least one operation (save/eval/format/show the '
@@ -149,7 +148,23 @@ def main():
     if args.work_dir is not None and rank == 0:
         mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
         timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
-        json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')
+        if args.aug_test:
+            json_file = osp.join(args.work_dir,
+                                 f'eval_multi_scale_{timestamp}.json')
+        else:
+            json_file = osp.join(args.work_dir,
+                                 f'eval_single_scale_{timestamp}.json')
+    elif rank == 0:
+        work_dir = osp.join('./work_dirs',
+                            osp.splitext(osp.basename(args.config))[0])
+        mmcv.mkdir_or_exist(osp.abspath(work_dir))
+        timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+        if args.aug_test:
+            json_file = osp.join(work_dir,
+                                 f'eval_multi_scale_{timestamp}.json')
+        else:
+            json_file = osp.join(work_dir,
+                                 f'eval_single_scale_{timestamp}.json')
 
     # build the dataloader
     # TODO: support multiple images per gpu (only minor changes are needed)
@@ -248,8 +263,7 @@
             eval_kwargs.update(metric=args.eval)
             metric = dataset.evaluate(results, **eval_kwargs)
             metric_dict = dict(config=args.config, metric=metric)
-            if args.work_dir is not None and rank == 0:
-                mmcv.dump(metric_dict, json_file, indent=4)
+            mmcv.dump(metric_dict, json_file, indent=4)
         if tmpdir is not None and eval_on_format_results:
             # remove tmp dir when cityscapes evaluation
             shutil.rmtree(tmpdir)
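
The evaluation JSON that test.py now always writes for rank 0 (eval_single_scale_*.json or eval_multi_scale_*.json) can be read back with mmcv for later comparison; a minimal sketch, with a purely illustrative path:

import mmcv

# illustrative path; the real file name carries the run timestamp
result = mmcv.load('./work_dirs/pspnet_r50-d8_512x1024_80k_cityscapes/'
                   'eval_single_scale_20220118_163831.json')
print(result['config'], result['metric'])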