mmsegmentation/.dev/benchmark_inference.py

150 lines
5.3 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import hashlib
import logging
import os
import os.path as osp
import warnings
from argparse import ArgumentParser
import requests
from mmcv import Config
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.utils import get_root_logger
# Ignore warnings raised while segmentors run inference.
warnings.filterwarnings('ignore')
def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir):
    """Download a checkpoint and verify it against the hash in its file name.

    The file is fetched from ``download.openmmlab.com`` under
    ``mmsegmentation/v0.5/<model_name>/<config_name>/<checkpoint_name>`` and
    written to ``collect_dir``. The first 8 hex digits of the SHA-256 digest
    of the saved file must equal the hash suffix of ``checkpoint_name`` (the
    text after the last ``-`` and before the extension).

    Args:
        checkpoint_name (str): Checkpoint file name ending in
            ``-<hash>.<ext>``.
        model_name (str): Model directory component of the download URL.
        config_name (str): Config directory component of the download URL.
        collect_dir (str): Local directory the checkpoint is written to.

    Raises:
        AssertionError: If the server answers with status 403, or if the
            saved file does not match the expected hash. On a hash mismatch
            the saved file is deleted before the error is raised.
    """
    url = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}'  # noqa
    checkpoint_path = osp.join(collect_dir, checkpoint_name)

    r = requests.get(url)
    # Raised explicitly instead of via ``assert`` so that the check is not
    # stripped when Python runs with ``-O``.
    if r.status_code == 403:
        raise AssertionError(f'{url} Access denied.')
    with open(checkpoint_path, 'wb') as fp:
        fp.write(r.content)

    # Take the text after the LAST hyphen: a name holding more than one '-'
    # (e.g. a backbone tag such as ``r50-d8``) would otherwise yield a piece
    # of the model name instead of the hash.
    true_hash_code = osp.splitext(checkpoint_name)[0].rsplit('-', 1)[-1]

    # check hash code
    sha256_cal = hashlib.sha256()
    with open(checkpoint_path, 'rb') as fp:
        sha256_cal.update(fp.read())
    cur_hash_code = sha256_cal.hexdigest()[:8]

    if cur_hash_code != true_hash_code:
        # Delete the bad file BEFORE raising; otherwise a later run would
        # find it on disk and skip the download.
        os.remove(checkpoint_path)
        raise AssertionError(
            f'{url} download failed, '
            'incomplete downloaded file or url invalid.')
def parse_args():
    """Parse the command line of the benchmark script.

    Returns:
        argparse.Namespace: Namespace with the attributes ``config``,
        ``checkpoint_root``, ``img``, ``aug``, ``model_name``, ``show`` and
        ``device``.
    """
    cli = ArgumentParser()

    # Required positional arguments, in command-line order.
    positionals = (
        ('config', 'test config file path'),
        ('checkpoint_root', 'Checkpoint file root path'),
    )
    for dest, help_text in positionals:
        cli.add_argument(dest, help=help_text)

    # Optional switches as (flags, keyword arguments) pairs; the order here
    # is the order in which they are registered on the parser.
    switches = (
        (('-i', '--img'), dict(default='demo/demo.png', help='Image file')),
        (('-a', '--aug'), dict(action='store_true', help='aug test')),
        (('-m', '--model-name'), dict(help='model name to inference')),
        (('-s', '--show'), dict(action='store_true', help='show results')),
        (('-d', '--device'),
         dict(default='cuda:0', help='Device used for inference')),
    )
    for flags, options in switches:
        cli.add_argument(*flags, **options)

    return cli.parse_args()
def inference_model(config_name, checkpoint, args, logger=None):
    """Build a segmentor and run it on a single image.

    Args:
        config_name (str): Path of the model config file to load.
        checkpoint (str): Path of the checkpoint passed to
            ``init_segmentor``.
        args (argparse.Namespace): Options; ``aug``, ``device``, ``img`` and
            ``show`` are read.
        logger (logging.Logger, optional): When given, receives the
            "unable to start aug test" message; otherwise the message is
            printed. Defaults to None.

    Returns:
        The value returned by ``inference_segmentor`` for ``args.img``.
    """
    cfg = Config.fromfile(config_name)

    if args.aug:
        # The second test-pipeline step is the one that may carry the
        # ``flip`` / ``img_scale`` keys to be overridden for aug testing.
        aug_step = cfg.data.test.pipeline[1]
        if 'flip' in aug_step and 'img_scale' in aug_step:
            aug_step.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
            aug_step.flip = True
        elif logger is not None:
            logger.error(f'{config_name}: unable to start aug test')
        else:
            print(f'{config_name}: unable to start aug test', flush=True)

    segmentor = init_segmentor(cfg, checkpoint, device=args.device)

    # Run inference on the single image given on the command line.
    prediction = inference_segmentor(segmentor, args.img)

    # Optionally display the prediction.
    if args.show:
        show_result_pyplot(segmentor, args.img, prediction)
    return prediction
# Sample test of whether the inference code is correct.
def main(args):
    """Run the inference check for one model or for every model.

    ``args.config`` is loaded as a benchmark config that maps each model name
    to one entry or a list of entries; every entry provides a ``config`` path
    and a ``checkpoint`` file name. ``args.checkpoint_root`` is created when
    it does not exist yet.

    * With ``args.model_name`` set, every entry of that model is run using
      checkpoints from ``args.checkpoint_root``. A failing entry is reported
      on stdout and the remaining entries still run.
    * Otherwise every model of the config is run. Checkpoints missing from
      ``args.checkpoint_root`` are downloaded first, and errors are logged to
      ``benchmark_inference_image.log``.

    Args:
        args (argparse.Namespace): Options as returned by ``parse_args``.

    Raises:
        RuntimeError: If ``args.model_name`` is set but is not a key of the
            benchmark config.
    """
    config = Config.fromfile(args.config)

    if not osp.exists(args.checkpoint_root):
        os.makedirs(args.checkpoint_root, 0o775)

    # test single model
    if args.model_name:
        if args.model_name not in config:
            raise RuntimeError('model name input error.')

        for model_info in _as_list(config[args.model_name]):
            config_name = model_info['config'].strip()
            print(f'processing: {config_name}', flush=True)
            checkpoint = osp.join(args.checkpoint_root,
                                  model_info['checkpoint'].strip())
            try:
                # build the model from a config file and a checkpoint file
                inference_model(config_name, checkpoint, args)
            except Exception:
                print(f'{config_name} test failed!')
        # Single-model mode ends here; it must never fall through into the
        # all-models run below.
        return

    # test all model
    logger = get_root_logger(
        log_file='benchmark_inference_image.log', log_level=logging.ERROR)

    for model_name in config:
        for model_info in _as_list(config[model_name]):
            print('processing: ', model_info['config'], flush=True)
            config_path = model_info['config'].strip()
            # ``splitext`` already drops the extension. Do not additionally
            # call ``rstrip('.py')``: it strips any trailing '.', 'p' or 'y'
            # characters and would corrupt names ending in those letters.
            config_name = osp.splitext(osp.basename(config_path))[0]
            checkpoint_name = model_info['checkpoint'].strip()
            checkpoint = osp.join(args.checkpoint_root, checkpoint_name)

            # ensure checkpoint exists
            try:
                if not osp.exists(checkpoint):
                    download_checkpoint(checkpoint_name, model_name,
                                        config_name, args.checkpoint_root)
            except Exception:
                logger.error(f'{checkpoint_name} download error')
                continue

            # test model inference with checkpoint
            try:
                # build the model from a config file and a checkpoint file
                inference_model(config_path, checkpoint, args, logger)
            except Exception as e:
                logger.error(f'{config_path} : {repr(e)}')


def _as_list(model_infos):
    """Return ``model_infos`` unchanged if it is a list, else wrap it in one."""
    return model_infos if isinstance(model_infos, list) else [model_infos]
if __name__ == '__main__':
    # Script entry point: parse the command line and run the benchmark.
    main(parse_args())