Merge branch 'xiexinchen/fix_benchmark_script' into 'refactor_dev'

[Refactor] Fix benchmark script

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!67

commit b2174812bb
@@ -7,11 +7,12 @@ import mmcv
 import numpy as np
 import torch
 from mmcv import Config
-from mmcv.parallel import MMDataParallel
 from mmcv.runner import load_checkpoint, wrap_fp16_model
+from mmengine.runner import Runner
+from mmengine.utils import revert_sync_batchnorm

-from mmseg.datasets import build_dataloader, build_dataset
-from mmseg.models import build_segmentor
+from mmseg.registry import MODELS
+from mmseg.utils import register_all_modules


 def parse_args():
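
The import changes summarize the whole refactor: the mmcv 1.x helpers (MMDataParallel, build_dataloader, build_dataset, build_segmentor) give way to mmengine's Runner and mmseg's registry. Below is a minimal sketch of the registry-driven construction this enables; the config path is hypothetical, and the imports mirror the patched script.

from mmcv import Config

from mmseg.registry import MODELS
from mmseg.utils import register_all_modules

# Fill the MODELS/DATASETS/... registries with mmseg's implementations;
# this must run before any MODELS.build() call.
register_all_modules()

cfg = Config.fromfile('configs/pspnet/pspnet.py')  # hypothetical config path
cfg.model.train_cfg = None       # the benchmark only needs the test branch
model = MODELS.build(cfg.model)  # resolves cfg.model.type via the registry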
@@ -31,7 +32,7 @@ def parse_args():

 def main():
     args = parse_args()
-
+    register_all_modules()
     cfg = Config.fromfile(args.config)
     timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
     if args.work_dir is not None:
@@ -48,32 +49,27 @@ def main():
     # set cudnn_benchmark
     torch.backends.cudnn.benchmark = False
     cfg.model.pretrained = None
-    cfg.data.test.test_mode = True

     benchmark_dict = dict(config=args.config, unit='img / s')
     overall_fps_list = []
     for time_index in range(repeat_times):
         print(f'Run {time_index + 1}:')
         # build the dataloader
-        # TODO: support multiple images per gpu (only minor changes are needed)
-        dataset = build_dataset(cfg.data.test)
-        data_loader = build_dataloader(
-            dataset,
-            samples_per_gpu=1,
-            workers_per_gpu=cfg.data.workers_per_gpu,
-            dist=False,
-            shuffle=False)
+        data_loader = Runner.build_dataloader(cfg.test_dataloader)

         # build the model and load checkpoint
         cfg.model.train_cfg = None
-        model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
+        model = MODELS.build(cfg.model)
         fp16_cfg = cfg.get('fp16', None)
         if fp16_cfg is not None:
             wrap_fp16_model(model)
         if 'checkpoint' in args and osp.exists(args.checkpoint):
             load_checkpoint(model, args.checkpoint, map_location='cpu')

-        model = MMDataParallel(model, device_ids=[0])
+        if torch.cuda.is_available():
+            model = model.cuda()
+        else:
+            model = revert_sync_batchnorm(model)

         model.eval()

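
Replacing the MMDataParallel wrapper with an explicit branch lets the benchmark run on machines without a GPU. The CPU path needs revert_sync_batchnorm because SyncBatchNorm layers require an initialized distributed process group. A small standalone demo, using the same import path as the script (newer mmengine releases expose the helper from mmengine.model):

import torch

from mmengine.utils import revert_sync_batchnorm

net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.SyncBatchNorm(8),  # forward would fail on CPU without a process group
)
net = revert_sync_batchnorm(net)  # SyncBatchNorm swapped for a plain BatchNorm variant
print(net[1])                     # no longer a SyncBatchNorm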
@@ -82,16 +78,19 @@ def main():
         pure_inf_time = 0
         total_iters = 200

-        # benchmark with 200 image and take the average
+        # benchmark with 200 batches and take the average
         for i, data in enumerate(data_loader):
+            batch_inputs, data_samples = model.data_preprocessor(data, True)

-            torch.cuda.synchronize()
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
             start_time = time.perf_counter()

             with torch.no_grad():
-                model(return_loss=False, rescale=True, **data)
+                model(batch_inputs, data_samples, mode='predict')

-            torch.cuda.synchronize()
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
             elapsed = time.perf_counter() - start_time

             if i >= num_warmup:
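
The inference call changes shape as well: an mmengine-style segmentor expects inputs that have already passed through its data_preprocessor (which moves the batch to the model's device and normalizes/stacks it) and a mode='predict' flag instead of return_loss=False, and the synchronize calls are guarded so timings stay valid on CPU. The diff is truncated at the warmup check, so the FPS bookkeeping below is a sketch that follows the usual benchmark pattern, with a dummy model and dataloader standing in for the real ones.

import time

import torch

model = torch.nn.Conv2d(3, 8, 3).eval()          # stand-in for the segmentor
data_loader = [torch.randn(1, 3, 64, 64)] * 12   # stand-in for the dataloader
num_warmup, pure_inf_time = 2, 0.0

for i, batch_inputs in enumerate(data_loader):
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain queued kernels before starting the clock
    start_time = time.perf_counter()

    with torch.no_grad():  # the real script calls model(batch_inputs, data_samples, mode='predict')
        model(batch_inputs)

    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the forward pass before stopping the clock
    elapsed = time.perf_counter() - start_time

    if i >= num_warmup:  # discard warmup iterations
        pure_inf_time += elapsed

fps = (len(data_loader) - num_warmup) / pure_inf_time
print(f'{fps:.2f} img / s')  # the unit recorded in benchmark_dict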