Merge branch 'xiexinchen/fix_benchmark_script' into 'refactor_dev'
[Refactor] Fix benchmark script See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!67pull/1801/head
commit
b2174812bb
|
@ -7,11 +7,12 @@ import mmcv
|
|||
import numpy as np
|
||||
import torch
|
||||
from mmcv import Config
|
||||
from mmcv.parallel import MMDataParallel
|
||||
from mmcv.runner import load_checkpoint, wrap_fp16_model
|
||||
from mmengine.runner import Runner
|
||||
from mmengine.utils import revert_sync_batchnorm
|
||||
|
||||
from mmseg.datasets import build_dataloader, build_dataset
|
||||
from mmseg.models import build_segmentor
|
||||
from mmseg.registry import MODELS
|
||||
from mmseg.utils import register_all_modules
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -31,7 +32,7 @@ def parse_args():
|
|||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
register_all_modules()
|
||||
cfg = Config.fromfile(args.config)
|
||||
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
||||
if args.work_dir is not None:
|
||||
|
@ -48,32 +49,27 @@ def main():
|
|||
# set cudnn_benchmark
|
||||
torch.backends.cudnn.benchmark = False
|
||||
cfg.model.pretrained = None
|
||||
cfg.data.test.test_mode = True
|
||||
|
||||
benchmark_dict = dict(config=args.config, unit='img / s')
|
||||
overall_fps_list = []
|
||||
for time_index in range(repeat_times):
|
||||
print(f'Run {time_index + 1}:')
|
||||
# build the dataloader
|
||||
# TODO: support multiple images per gpu (only minor changes are needed)
|
||||
dataset = build_dataset(cfg.data.test)
|
||||
data_loader = build_dataloader(
|
||||
dataset,
|
||||
samples_per_gpu=1,
|
||||
workers_per_gpu=cfg.data.workers_per_gpu,
|
||||
dist=False,
|
||||
shuffle=False)
|
||||
data_loader = Runner.build_dataloader(cfg.test_dataloader)
|
||||
|
||||
# build the model and load checkpoint
|
||||
cfg.model.train_cfg = None
|
||||
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
|
||||
model = MODELS.build(cfg.model)
|
||||
fp16_cfg = cfg.get('fp16', None)
|
||||
if fp16_cfg is not None:
|
||||
wrap_fp16_model(model)
|
||||
if 'checkpoint' in args and osp.exists(args.checkpoint):
|
||||
load_checkpoint(model, args.checkpoint, map_location='cpu')
|
||||
|
||||
model = MMDataParallel(model, device_ids=[0])
|
||||
if torch.cuda.is_available():
|
||||
model = model.cuda()
|
||||
else:
|
||||
model = revert_sync_batchnorm(model)
|
||||
|
||||
model.eval()
|
||||
|
||||
|
@ -82,16 +78,19 @@ def main():
|
|||
pure_inf_time = 0
|
||||
total_iters = 200
|
||||
|
||||
# benchmark with 200 image and take the average
|
||||
# benchmark with 200 batches and take the average
|
||||
for i, data in enumerate(data_loader):
|
||||
batch_inputs, data_samples = model.data_preprocessor(data, True)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.perf_counter()
|
||||
|
||||
with torch.no_grad():
|
||||
model(return_loss=False, rescale=True, **data)
|
||||
model(batch_inputs, data_samples, mode='predict')
|
||||
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.perf_counter() - start_time
|
||||
|
||||
if i >= num_warmup:
|
||||
|
|
Loading…
Reference in New Issue