# Copyright (c) OpenMMLab. All rights reserved.
import hashlib
import logging
import os
import os.path as osp
import warnings
from argparse import ArgumentParser

import requests
from mmengine import Config

from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import get_root_logger

# ignore warnings when segmentors inference
warnings.filterwarnings('ignore')


def download_checkpoint(checkpoint_name, model_name, config_name, collect_dir):
    """Download a checkpoint and verify it against the hash in its file name.

    The file is fetched from the OpenMMLab download server and written to
    ``collect_dir``. The expected hash is the second ``-``-separated field of
    the checkpoint file stem; it is compared with the first 8 hex digits of
    the SHA-256 digest of the file that was written.

    Args:
        checkpoint_name: File name of the checkpoint (used in the URL and as
            the local file name).
        model_name: Model directory component of the download URL.
        config_name: Config directory component of the download URL.
        collect_dir: Local directory the checkpoint is written to.

    Raises:
        AssertionError: If the server answers 403, or if the hash of the
            downloaded file does not match the one in ``checkpoint_name``.
            On a hash mismatch the downloaded file is deleted first, so a
            later run does not mistake it for a valid checkpoint.
    """
    url = f'https://download.openmmlab.com/mmsegmentation/v0.5/{model_name}/{config_name}/{checkpoint_name}'  # noqa

    r = requests.get(url)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if r.status_code == 403:
        raise AssertionError(f'{url} Access denied.')

    checkpoint_path = osp.join(collect_dir, checkpoint_name)
    with open(checkpoint_path, 'wb') as code:
        code.write(r.content)

    true_hash_code = osp.splitext(checkpoint_name)[0].split('-')[1]

    # check hash code
    with open(checkpoint_path, 'rb') as fp:
        sha256_cal = hashlib.sha256()
        sha256_cal.update(fp.read())
        cur_hash_code = sha256_cal.hexdigest()[:8]

    if cur_hash_code != true_hash_code:
        # Remove the bad file *before* raising; otherwise it stays on disk
        # and the existence check in ``main`` would skip re-downloading it.
        os.remove(checkpoint_path)
        raise AssertionError(
            f'{url} download failed, '
            'incomplete downloaded file or url invalid.')


def parse_args():
    """Parse the command-line arguments of the benchmark script."""
    parser = ArgumentParser()
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint_root', help='Checkpoint file root path')
    parser.add_argument(
        '-i', '--img', default='demo/demo.png', help='Image file')
    parser.add_argument('-a', '--aug', action='store_true', help='aug test')
    parser.add_argument('-m', '--model-name', help='model name to inference')
    parser.add_argument(
        '-s', '--show', action='store_true', help='show results')
    parser.add_argument(
        '-d', '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()
    return args


def inference(config_name, checkpoint, args, logger=None):
    """Build a model from a config/checkpoint and run it on ``args.img``.

    Args:
        config_name: Path of the model config file to load.
        checkpoint: Path of the checkpoint file to load.
        args: Parsed arguments; ``aug``, ``device``, ``img`` and ``show`` are
            read.
        logger: Optional logger used to report that aug test could not be
            enabled; when ``None`` the message is printed instead.

    Returns:
        The value returned by ``inference_model`` for the image.
    """
    cfg = Config.fromfile(config_name)
    if args.aug:
        if 'flip' in cfg.data.test.pipeline[
                1] and 'img_scale' in cfg.data.test.pipeline[1]:
            cfg.data.test.pipeline[1].img_ratios = [
                0.5, 0.75, 1.0, 1.25, 1.5, 1.75
            ]
            cfg.data.test.pipeline[1].flip = True
        else:
            # Aug test cannot be configured; report it and continue with the
            # unmodified config.
            if logger is not None:
                logger.error(f'{config_name}: unable to start aug test')
            else:
                print(f'{config_name}: unable to start aug test', flush=True)

    model = init_model(cfg, checkpoint, device=args.device)
    # test a single image
    result = inference_model(model, args.img)

    # show the results
    if args.show:
        show_result_pyplot(model, args.img, result)
    return result


# Sample test whether the inference code is correct
def main(args):
    """Run inference for one model, or for every model listed in the config.

    With ``args.model_name`` set, only that entry of the config is tested and
    checkpoints are expected to already exist under ``args.checkpoint_root``.
    Otherwise every entry is tested, missing checkpoints are downloaded, and
    failures are logged to ``benchmark_inference_image.log``.

    Raises:
        RuntimeError: If ``args.model_name`` is given but is not a key of the
            config.
    """
    config = Config.fromfile(args.config)

    if not os.path.exists(args.checkpoint_root):
        os.makedirs(args.checkpoint_root, 0o775)

    # test single model
    if args.model_name:
        if args.model_name not in config:
            raise RuntimeError('model name input error.')

        model_infos = config[args.model_name]
        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        for model_info in model_infos:
            config_name = model_info['config'].strip()
            print(f'processing: {config_name}', flush=True)
            checkpoint = osp.join(args.checkpoint_root,
                                  model_info['checkpoint'].strip())
            try:
                # build the model from a config file and a checkpoint file
                inference(config_name, checkpoint, args)
            except Exception:
                # Best effort: report the failure and move on to the next
                # entry of this model.
                print(f'{config_name} test failed!')
        return

    # test all model
    logger = get_root_logger(
        log_file='benchmark_inference_image.log', log_level=logging.ERROR)

    for model_name in config:
        model_infos = config[model_name]

        if not isinstance(model_infos, list):
            model_infos = [model_infos]
        for model_info in model_infos:
            print('processing: ', model_info['config'], flush=True)
            config_path = model_info['config'].strip()
            # ``splitext`` already drops the ``.py`` extension. Do not use
            # ``rstrip('.py')`` here: it strips any trailing '.', 'p' or 'y'
            # characters and would corrupt names ending in those letters.
            config_name = osp.splitext(osp.basename(config_path))[0]
            checkpoint_name = model_info['checkpoint'].strip()
            checkpoint = osp.join(args.checkpoint_root, checkpoint_name)

            # ensure checkpoint exists
            try:
                if not osp.exists(checkpoint):
                    download_checkpoint(checkpoint_name, model_name,
                                        config_name, args.checkpoint_root)
            except Exception:
                logger.error(f'{checkpoint_name} download error')
                continue

            # test model inference with checkpoint
            try:
                # build the model from a config file and a checkpoint file
                inference(config_path, checkpoint, args, logger)
            except Exception as e:
                logger.error(f'{config_path} " : {repr(e)}')


if __name__ == '__main__':
    args = parse_args()
    main(args)