mmsegmentation/.dev/benchmark_inference.py
sennnnn d3dc4f9583 [Enhancement] Add Dev tools to boost develop (#798)
* Modify default work dir when training.

* Refactor gather_models.py.

* Add train and test matching list.

* Regression benchmark list.

* Rename lowercase readme files to uppercase README.

* Add url check tool and model inference test tool.

* Modify tool name.

* Support duplicate mode of log json url check.

* Add regression benchmark evaluation automatic tool.

* Add train script generator.

* Only support running as a script.

* Add evaluation results gather.

* Add execute permission.

* Automatically make checkpoint root folder.

* Modify gather results save path.

* Coarse-grained train results gather tool.

* Complete benchmark train script.

* Make some little modifications.

* Fix checkpoint urls.

* Fix unet checkpoint urls.

* Fix fast scnn & fcn checkpoint url.

* Fix fast scnn checkpoint urls.

* Fix fast scnn url.

* Add differential results calculation.

* Add differential results of regression benchmark train results.

* Add an extra argument to select model.

* Update nonlocal_net & hrnet checkpoint url.

* Fix the HRNet checkpoint URL, fix some TTA evaluation results, and modify the gather models tool.

* Modify fast scnn checkpoint url.

* Resolve new comments.

* Fix url check status code bug.

* Resolve some comments.

* Modify train scripts generator.

* Modify work_dir of regression benchmark results.

* Modify the model gather tool.
2021-09-02 09:44:51 -07:00

150 lines
5.3 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import hashlib
import logging
import os
import os.path as osp
import warnings
from argparse import ArgumentParser
import requests
from mmcv import Config
from mmseg.apis import inference_segmentor, init_segmentor, show_result_pyplot
from mmseg.utils import get_root_logger
# Silence every warning category for the whole run: inference over many
# segmentors otherwise floods the output with library warnings.
warnings.simplefilter('ignore')
def _expected_hash_code(checkpoint_name):
    """Return the hash suffix embedded in a checkpoint file name.

    The suffix is the text after the *last* hyphen of the file stem, e.g.
    ``'efe53f0d'`` for
    ``'fcn_r50-d8_512x1024_40k_cityscapes_20200604_192608-efe53f0d.pth'``.
    Model names may themselves contain hyphens (``r50-d8``), so taking the
    field after the *first* hyphen would compare against the wrong text.
    """
    return osp.splitext(checkpoint_name)[0].rsplit('-', 1)[-1]


def _sha256_prefix(data, length=8):
    """Return the first ``length`` hex digits of the SHA-256 of ``data``."""
    return hashlib.sha256(data).hexdigest()[:length]


def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir):
    """Download a checkpoint into ``collect_dir`` and verify its hash code.

    The file is fetched from the OpenMMLab mmsegmentation v0.5 download
    server. It is written to ``osp.join(collect_dir, checkpoint_name)`` only
    after the first 8 hex digits of its SHA-256 match the hash suffix of
    ``checkpoint_name``. Verifying *before* writing means a corrupt download
    never lands on disk, so a later ``osp.exists`` check cannot mistake it
    for a valid checkpoint.

    Args:
        checkpoint_name (str): Checkpoint file name, expected to end in
            ``-<hash>.pth``.
        model_name (str): Model directory on the download server.
        config_name (str): Config name (without extension) used as the
            sub-directory on the download server.
        collect_dir (str): Local directory the checkpoint is saved into.

    Raises:
        RuntimeError: If the HTTP request does not return status 200, or the
            downloaded content does not match the expected hash code.
    """
    url = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}'  # noqa
    r = requests.get(url)
    if r.status_code == 403:
        raise RuntimeError(f'{url} Access denied.')
    if r.status_code != 200:
        raise RuntimeError(
            f'{url} request failed with status code {r.status_code}.')
    true_hash_code = _expected_hash_code(checkpoint_name)
    cur_hash_code = _sha256_prefix(r.content)
    if cur_hash_code != true_hash_code:
        raise RuntimeError(f'{url} download failed, '
                           'incomplete downloaded file or url invalid.')
    with open(osp.join(collect_dir, checkpoint_name), 'wb') as fp:
        fp.write(r.content)
def parse_args():
    """Parse the command line of the inference benchmark.

    Positional arguments are the benchmark config file and the directory
    holding checkpoints. Optional flags select the demo image, enable
    augmented (multi-scale + flip) testing, restrict the run to one model
    entry, show the prediction, and choose the inference device.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    # (flags, keyword arguments) in the order they appear in ``--help``.
    argument_specs = [
        (('config', ), {'help': 'test config file path'}),
        (('checkpoint_root', ), {'help': 'Checkpoint file root path'}),
        (('-i', '--img'), {'default': 'demo/demo.png', 'help': 'Image file'}),
        (('-a', '--aug'), {'action': 'store_true', 'help': 'aug test'}),
        (('-m', '--model-name'), {'help': 'model name to inference'}),
        (('-s', '--show'), {'action': 'store_true', 'help': 'show results'}),
        (('-d', '--device'), {
            'default': 'cuda:0',
            'help': 'Device used for inference'
        }),
    ]
    cli = ArgumentParser()
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def inference_model(config_name, checkpoint, args, logger=None):
    """Run a single-image inference with one config/checkpoint pair.

    Args:
        config_name (str): Path of the model config file.
        checkpoint (str): Path of the checkpoint file to load.
        args (argparse.Namespace): Parsed CLI arguments; ``aug``, ``device``,
            ``img`` and ``show`` are read here.
        logger (logging.Logger, optional): Where to report that augmented
            testing cannot be enabled. Falls back to ``print`` when None.

    Returns:
        The value returned by ``inference_segmentor`` for ``args.img``.
    """
    cfg = Config.fromfile(config_name)
    if args.aug:
        # The second test-pipeline step is checked for ``flip`` and
        # ``img_scale`` keys; only then are multi-scale ratios and flipping
        # switched on.
        aug_step = cfg.data.test.pipeline[1]
        if 'flip' in aug_step and 'img_scale' in aug_step:
            aug_step.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
            aug_step.flip = True
        else:
            message = f'{config_name}: unable to start aug test'
            if logger is None:
                print(message, flush=True)
            else:
                logger.error(message)

    model = init_segmentor(cfg, checkpoint, device=args.device)
    result = inference_segmentor(model, args.img)
    if args.show:
        show_result_pyplot(model, args.img, result)
    return result
def _as_list(model_infos):
    """Wrap a single model-info entry in a list; return lists unchanged."""
    return model_infos if isinstance(model_infos, list) else [model_infos]


def _ensure_checkpoint(model_name, model_info, checkpoint_root):
    """Return the local checkpoint path for ``model_info``.

    If the file is not already in ``checkpoint_root`` it is fetched first
    with ``download_checkpoint``; any error from that call propagates.
    """
    config_path = model_info['config'].strip()
    checkpoint_name = model_info['checkpoint'].strip()
    checkpoint = osp.join(checkpoint_root, checkpoint_name)
    if not osp.exists(checkpoint):
        # Config file name without directory and extension. Note that
        # ``str.rstrip('.py')`` would strip any trailing '.', 'p' or 'y'
        # characters (not the suffix), so it must not be used here.
        config_name = osp.splitext(osp.basename(config_path))[0]
        download_checkpoint(checkpoint_name, model_name, config_name,
                            checkpoint_root)
    return checkpoint


# Sample test whether the inference code is correct
def main(args):
    """Smoke-test single-image inference for the models in ``args.config``.

    ``args.config`` maps model names to one entry (or a list of entries),
    each with ``config`` and ``checkpoint`` keys. Missing checkpoints are
    downloaded into ``args.checkpoint_root``, which is created if needed.

    With ``args.model_name`` set, only that model's entries are tested and
    failures are printed. Otherwise every model is tested and failures are
    written to ``benchmark_inference_image.log``.

    Raises:
        RuntimeError: If ``args.model_name`` is not a key of the config.
    """
    config = Config.fromfile(args.config)
    os.makedirs(args.checkpoint_root, 0o775, exist_ok=True)

    # test single model
    if args.model_name:
        if args.model_name not in config:
            raise RuntimeError('model name input error.')
        for model_info in _as_list(config[args.model_name]):
            config_name = model_info['config'].strip()
            print(f'processing: {config_name}', flush=True)
            try:
                checkpoint = _ensure_checkpoint(args.model_name, model_info,
                                                args.checkpoint_root)
                # build the model from a config file and a checkpoint file
                inference_model(config_name, checkpoint, args)
            except Exception as e:
                # Best effort: report and keep testing the remaining entries.
                print(f'{config_name} test failed! {e!r}')
        return

    # test all models
    logger = get_root_logger(
        log_file='benchmark_inference_image.log', log_level=logging.ERROR)
    for model_name in config:
        for model_info in _as_list(config[model_name]):
            print('processing: ', model_info['config'], flush=True)
            config_path = model_info['config'].strip()
            checkpoint_name = model_info['checkpoint'].strip()
            # ensure checkpoint exists
            try:
                checkpoint = _ensure_checkpoint(model_name, model_info,
                                                args.checkpoint_root)
            except Exception:
                logger.error(f'{checkpoint_name} download error')
                continue
            # test model inference with checkpoint
            try:
                # build the model from a config file and a checkpoint file
                inference_model(config_path, checkpoint, args, logger)
            except Exception as e:
                logger.error(f'{config_path} : {e!r}')
if __name__ == '__main__':
    # Script entry point: parse the CLI and run the benchmark.
    main(parse_args())