339 lines
13 KiB
Python
339 lines
13 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
import os
|
|
import os.path as osp
|
|
import shutil
|
|
import warnings
|
|
from typing import Any, Iterable
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
import torch
|
|
from mmcv.parallel import MMDataParallel
|
|
from mmcv.runner import get_dist_info
|
|
from mmcv.utils import DictAction
|
|
|
|
from mmseg.apis import single_gpu_test
|
|
from mmseg.datasets import build_dataloader, build_dataset
|
|
from mmseg.models.segmentors.base import BaseSegmentor
|
|
from mmseg.ops import resize
|
|
|
|
|
|
class ONNXRuntimeSegmentor(BaseSegmentor):
    """Segmentor that runs inference through an ONNXRuntime session.

    Only ``simple_test`` is implemented; feature extraction, training,
    ``encode_decode`` and test-time augmentation raise
    ``NotImplementedError``.

    Args:
        onnx_file (str): Path to the ONNX model file.
        cfg (Any): Loaded config; ``cfg.model.test_cfg.mode`` is read.
        device_id (int): CUDA device index used for the CUDA execution
            provider and for IO binding.
    """

    def __init__(self, onnx_file: str, cfg: Any, device_id: int):
        super(ONNXRuntimeSegmentor, self).__init__()
        import onnxruntime as ort

        # get the custom op path
        ort_custom_op_path = ''
        try:
            from mmcv.ops import get_onnxruntime_op_path
            ort_custom_op_path = get_onnxruntime_op_path()
        except ImportError:  # ModuleNotFoundError is a subclass
            # Implicit concatenation keeps source indentation out of the
            # message (a backslash continuation inside the literal did not).
            warnings.warn('If input model has custom op from mmcv, '
                          'you may have to build mmcv with ONNXRuntime '
                          'from source.')
        session_options = ort.SessionOptions()
        # register custom op for onnxruntime
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)

        # CUDA first (when this ORT build reports a GPU device), CPU as
        # fallback; the option dicts line up index-for-index with providers.
        providers = ['CPUExecutionProvider']
        options = [{}]
        is_cuda_available = ort.get_device() == 'GPU'
        if is_cuda_available:
            providers.insert(0, 'CUDAExecutionProvider')
            options.insert(0, {'device_id': device_id})

        # Providers must be given at construction time: since ONNXRuntime
        # 1.9 a session created without ``providers`` raises on builds that
        # enable more than the CPU provider.
        sess = ort.InferenceSession(
            onnx_file,
            session_options,
            providers=providers,
            provider_options=options)

        self.sess = sess
        self.device_id = device_id
        # Outputs are bound once (no device given) and fetched per call via
        # ``copy_outputs_to_cpu``.
        self.io_binding = sess.io_binding()
        self.output_names = [_.name for _ in sess.get_outputs()]
        for name in self.output_names:
            self.io_binding.bind_output(name)
        self.cfg = cfg
        self.test_mode = cfg.model.test_cfg.mode
        self.is_cuda_available = is_cuda_available

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def encode_decode(self, img, img_metas):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def simple_test(self, img: torch.Tensor, img_meta: Iterable,
                    **kwargs) -> list:
        """Run the ONNX model on ``img`` and return the segmentation map.

        Args:
            img (torch.Tensor): Input batch bound as the ``'input'`` node.
            img_meta (Iterable): Per-image meta dicts; ``img_meta[0]
                ['ori_shape']`` is used to resize the prediction.

        Returns:
            list: ``list(seg_pred[0])`` of the first model output
            (presumably one ``(H, W)`` map per image — confirm against the
            exporter).
        """
        if not self.is_cuda_available:
            img = img.detach().cpu()
        elif self.device_id >= 0:
            img = img.cuda(self.device_id)
        # IO binding reads raw memory via ``data_ptr()`` and declares it as
        # float32, so the buffer must be a dense float32 tensor. Both calls
        # return ``img`` unchanged when it already is one.
        img = img.float().contiguous()
        device_type = img.device.type
        self.io_binding.bind_input(
            name='input',
            device_type=device_type,
            device_id=self.device_id,
            element_type=np.float32,
            shape=img.shape,
            buffer_ptr=img.data_ptr())
        self.sess.run_with_iobinding(self.io_binding)
        seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
        # whole might support dynamic reshape
        ori_shape = img_meta[0]['ori_shape']
        if not (ori_shape[0] == seg_pred.shape[-2]
                and ori_shape[1] == seg_pred.shape[-1]):
            # Nearest interpolation keeps label ids intact.
            seg_pred = torch.from_numpy(seg_pred).float()
            seg_pred = resize(
                seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
            seg_pred = seg_pred.long().detach().cpu().numpy()
        seg_pred = seg_pred[0]
        seg_pred = list(seg_pred)
        return seg_pred

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')
|
|
|
|
|
|
class TensorRTSegmentor(BaseSegmentor):
    """Segmentor that runs inference through an mmcv TensorRT wrapper.

    Only ``simple_test`` is implemented; feature extraction, training,
    ``encode_decode`` and test-time augmentation raise
    ``NotImplementedError``.

    Args:
        trt_file (str): Path to the TensorRT engine file.
        cfg (Any): Loaded config; ``cfg.model.test_cfg.mode`` is read.
        device_id (int): CUDA device index made current during inference.
    """

    def __init__(self, trt_file: str, cfg: Any, device_id: int):
        super(TensorRTSegmentor, self).__init__()
        from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except ImportError:  # ModuleNotFoundError is a subclass
            # Implicit concatenation keeps source indentation out of the
            # message (a backslash continuation inside the literal did not).
            warnings.warn('If input model has custom op from mmcv, '
                          'you may have to build mmcv with TensorRT '
                          'from source.')
        # The engine is addressed by the node names 'input' / 'output'.
        model = TRTWraper(
            trt_file, input_names=['input'], output_names=['output'])

        self.model = model
        self.device_id = device_id
        self.cfg = cfg
        self.test_mode = cfg.model.test_cfg.mode

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def encode_decode(self, img, img_metas):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def simple_test(self, img: torch.Tensor, img_meta: Iterable,
                    **kwargs) -> list:
        """Run the TensorRT engine on ``img`` and return the segmentation map.

        Args:
            img (torch.Tensor): Input batch fed as the ``'input'`` binding.
            img_meta (Iterable): Per-image meta dicts; ``img_meta[0]
                ['ori_shape']`` is used to resize the prediction.

        Returns:
            list: ``list(seg_pred[0])`` of the ``'output'`` binding
            (presumably one ``(H, W)`` map per image — confirm against the
            exporter).
        """
        # Run on the configured GPU with autograd disabled.
        with torch.cuda.device(self.device_id), torch.no_grad():
            seg_pred = self.model({'input': img})['output']
        seg_pred = seg_pred.detach().cpu().numpy()
        # whole might support dynamic reshape
        ori_shape = img_meta[0]['ori_shape']
        if not (ori_shape[0] == seg_pred.shape[-2]
                and ori_shape[1] == seg_pred.shape[-1]):
            # Nearest interpolation keeps label ids intact.
            seg_pred = torch.from_numpy(seg_pred).float()
            seg_pred = resize(
                seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
            seg_pred = seg_pred.long().detach().cpu().numpy()
        seg_pred = seg_pred[0]
        seg_pred = list(seg_pred)
        return seg_pred

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for backend testing / evaluation.

    Side effects: sets the ``LOCAL_RANK`` environment variable from
    ``--local_rank`` when it is not already set, and copies the deprecated
    ``--options`` into ``cfg_options`` (with a warning).

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        ValueError: If both ``--options`` and ``--cfg-options`` are given.
    """
    parser = argparse.ArgumentParser(
        description='mmseg backend test (and eval)')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('model', help='Input model file')
    # Required: main() has no fallback backend, so omitting this used to
    # crash later with an unbound ``model``.
    parser.add_argument(
        '--backend',
        required=True,
        help='Backend of the model.',
        choices=['onnxruntime', 'tensorrt'])
    parser.add_argument('--out', help='output result file in pickle format')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It '
        'is useful when you want to format the result to a specific format '
        'and submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
        ' for generic datasets, and "cityscapes" for Cityscapes')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where painted images will be saved')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='--options is deprecated in favor of --cfg-options and it will '
        'not be supported in version v0.22.0. Override some settings in the '
        'used config, the key-value pair in xxx=yyy format will be merged '
        'into config file. If the value to be overwritten is a list, it '
        'should be like key="[a,b]" or key=a,b. It also allows nested '
        'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation '
        'marks are necessary and that no white space is allowed.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation')
    parser.add_argument(
        '--opacity',
        type=float,
        default=0.5,
        help='Opacity of painted segmentation map. In (0, 1] range.')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.cfg_options:
        raise ValueError(
            '--options and --cfg-options cannot be both '
            'specified, --options is deprecated in favor of --cfg-options. '
            '--options will not be supported in version v0.22.0.')
    if args.options:
        warnings.warn('--options is deprecated in favor of --cfg-options. '
                      '--options will not be supported in version v0.22.0.')
        args.cfg_options = args.options

    return args
|
|
|
|
|
|
def main():
    """Evaluate an exported ONNXRuntime or TensorRT segmentation model.

    Parses the CLI, builds the test dataset/dataloader from the config,
    wraps the exported model in the matching backend segmentor, runs
    ``single_gpu_test`` and then dumps (``--out``), evaluates (``--eval``)
    and/or formats the results as requested.

    Raises:
        ValueError: On conflicting arguments, a non-pickle ``--out`` path,
            or an unsupported backend.
    """
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format-only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = mmcv.Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # Backend inference in this tool always runs in a single process.
    distributed = False

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False)

    # Only inference is performed, so drop any training config.
    cfg.model.train_cfg = None

    if args.backend == 'onnxruntime':
        model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
    elif args.backend == 'tensorrt':
        model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0)
    else:
        # Fail clearly instead of hitting an unbound ``model`` below.
        raise ValueError(f'Unsupported backend: {args.backend!r}, expected '
                         "'onnxruntime' or 'tensorrt'.")

    model.CLASSES = dataset.CLASSES
    model.PALETTE = dataset.PALETTE

    # clean gpu memory when starting a new evaluation.
    torch.cuda.empty_cache()
    eval_kwargs = {} if args.eval_options is None else args.eval_options

    # Deprecated
    efficient_test = eval_kwargs.get('efficient_test', False)
    if efficient_test:
        warnings.warn(
            '``efficient_test=True`` does not have effect in tools/test.py, '
            'the evaluation and format results are CPU memory efficient by '
            'default')

    # Cityscapes evaluation works on formatted result files rather than on
    # pre-eval results.
    eval_on_format_results = (
        args.eval is not None and 'cityscapes' in args.eval)
    if eval_on_format_results:
        assert len(args.eval) == 1, 'eval on format results is not ' \
                                    'applicable for metrics other than ' \
                                    'cityscapes'
    if args.format_only or eval_on_format_results:
        if 'imgfile_prefix' in eval_kwargs:
            tmpdir = eval_kwargs['imgfile_prefix']
        else:
            tmpdir = '.format_cityscapes'
            eval_kwargs.setdefault('imgfile_prefix', tmpdir)
        mmcv.mkdir_or_exist(tmpdir)
    else:
        tmpdir = None

    model = MMDataParallel(model, device_ids=[0])
    results = single_gpu_test(
        model,
        data_loader,
        args.show,
        args.show_dir,
        False,
        args.opacity,
        pre_eval=args.eval is not None and not eval_on_format_results,
        format_only=args.format_only or eval_on_format_results,
        format_args=eval_kwargs)

    rank, _ = get_dist_info()
    if rank == 0:
        if args.out:
            warnings.warn(
                'The behavior of ``args.out`` has been changed since MMSeg '
                'v0.16, the pickled outputs could be seg map as type of '
                'np.array, pre-eval results or file paths for '
                '``dataset.format_results()``.')
            print(f'\nwriting results to {args.out}')
            mmcv.dump(results, args.out)
        if args.eval:
            dataset.evaluate(results, args.eval, **eval_kwargs)
        if tmpdir is not None and eval_on_format_results:
            # remove tmp dir when cityscapes evaluation
            shutil.rmtree(tmpdir)
|
|
|
|
|
|
if __name__ == '__main__':
    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    # Emit the deprecation notice before running, so it is shown up front
    # and is not lost when main() raises.
    warnings.warn(msg)

    main()
|