mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Feature]: Support FCN,DeeepLabV3, DeepLabV3Plus in mmseg with ONNXRuntime and TensorRT (#31)
* fix mask empty result * support fcn exporting to ONNX for ort and trt in whole mode * resolve comments * remove unnecessary code * update prepare_input * rewrite psp_head & aspp_head * test fcn deeplabv3 deeplabv3plus with trt
This commit is contained in:
parent
dcb88e4439
commit
7dbc12d23f
@ -1,2 +1,2 @@
|
|||||||
[settings]
|
[settings]
|
||||||
known_third_party = mmcls,mmcv,mmdet,numpy,onnx,packaging,pyppl,pytest,setuptools,tensorrt,torch
|
known_third_party = mmcls,mmcv,mmdet,mmseg,numpy,onnx,packaging,pyppl,pytest,setuptools,tensorrt,torch
|
||||||
|
18
configs/mmseg/base.py
Normal file
18
configs/mmseg/base.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
_base_ = ['../_base_/torch2onnx.py']
|
||||||
|
codebase = 'mmseg'
|
||||||
|
pytorch2onnx = dict(
|
||||||
|
input_names=['input'],
|
||||||
|
output_names=['output'],
|
||||||
|
dynamic_axes={
|
||||||
|
'input': {
|
||||||
|
0: 'batch',
|
||||||
|
2: 'height',
|
||||||
|
3: 'width'
|
||||||
|
},
|
||||||
|
'output': {
|
||||||
|
0: 'batch',
|
||||||
|
2: 'height',
|
||||||
|
3: 'width'
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
1
configs/mmseg/onnxruntime.py
Normal file
1
configs/mmseg/onnxruntime.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
_base_ = ['./base.py', '../_base_/backends/onnxruntime.py']
|
7
configs/mmseg/tensorrt.py
Normal file
7
configs/mmseg/tensorrt.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
_base_ = ['./base.py', '../_base_/backends/tensorrt.py']
|
||||||
|
tensorrt_params = dict(model_params=[
|
||||||
|
dict(
|
||||||
|
opt_shape_dict=dict(
|
||||||
|
input=[[1, 3, 512, 512], [1, 3, 1024, 2048], [1, 3, 2048, 2048]]),
|
||||||
|
max_workspace_size=1 << 30)
|
||||||
|
])
|
@ -10,13 +10,13 @@ from mmdeploy.apis.utils import assert_module_exist
|
|||||||
|
|
||||||
|
|
||||||
def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]):
|
def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]):
|
||||||
|
assert_module_exist(codebase)
|
||||||
# load model_cfg if necessary
|
# load model_cfg if necessary
|
||||||
if isinstance(model_cfg, str):
|
if isinstance(model_cfg, str):
|
||||||
model_cfg = mmcv.Config.fromfile(model_cfg)
|
model_cfg = mmcv.Config.fromfile(model_cfg)
|
||||||
|
|
||||||
if codebase == 'mmcls':
|
if codebase == 'mmcls':
|
||||||
from mmcls.datasets import (build_dataloader, build_dataset)
|
from mmcls.datasets import (build_dataloader, build_dataset)
|
||||||
assert_module_exist(codebase)
|
|
||||||
# build dataset and dataloader
|
# build dataset and dataloader
|
||||||
dataset = build_dataset(model_cfg.data.test)
|
dataset = build_dataset(model_cfg.data.test)
|
||||||
data_loader = build_dataloader(
|
data_loader = build_dataloader(
|
||||||
@ -27,7 +27,6 @@ def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]):
|
|||||||
round_up=False)
|
round_up=False)
|
||||||
|
|
||||||
elif codebase == 'mmdet':
|
elif codebase == 'mmdet':
|
||||||
assert_module_exist(codebase)
|
|
||||||
from mmdet.datasets import (build_dataloader, build_dataset,
|
from mmdet.datasets import (build_dataloader, build_dataset,
|
||||||
replace_ImageToTensor)
|
replace_ImageToTensor)
|
||||||
# in case the test dataset is concatenated
|
# in case the test dataset is concatenated
|
||||||
@ -58,7 +57,17 @@ def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]):
|
|||||||
workers_per_gpu=model_cfg.data.workers_per_gpu,
|
workers_per_gpu=model_cfg.data.workers_per_gpu,
|
||||||
dist=False,
|
dist=False,
|
||||||
shuffle=False)
|
shuffle=False)
|
||||||
|
elif codebase == 'mmseg':
|
||||||
|
from mmseg.datasets import build_dataset, build_dataloader
|
||||||
|
model_cfg.data.test.test_mode = True
|
||||||
|
dataset = build_dataset(model_cfg.data.test)
|
||||||
|
samples_per_gpu = 1
|
||||||
|
data_loader = build_dataloader(
|
||||||
|
dataset,
|
||||||
|
samples_per_gpu=samples_per_gpu,
|
||||||
|
workers_per_gpu=model_cfg.data.workers_per_gpu,
|
||||||
|
dist=False,
|
||||||
|
shuffle=False)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||||
|
|
||||||
@ -71,16 +80,18 @@ def single_gpu_test(codebase: str,
|
|||||||
show: bool = False,
|
show: bool = False,
|
||||||
out_dir: Any = None,
|
out_dir: Any = None,
|
||||||
show_score_thr: float = 0.3):
|
show_score_thr: float = 0.3):
|
||||||
if codebase == 'mmcls':
|
|
||||||
assert_module_exist(codebase)
|
assert_module_exist(codebase)
|
||||||
|
|
||||||
|
if codebase == 'mmcls':
|
||||||
from mmcls.apis import single_gpu_test
|
from mmcls.apis import single_gpu_test
|
||||||
outputs = single_gpu_test(model, data_loader, show, out_dir)
|
outputs = single_gpu_test(model, data_loader, show, out_dir)
|
||||||
elif codebase == 'mmdet':
|
elif codebase == 'mmdet':
|
||||||
assert_module_exist(codebase)
|
|
||||||
from mmdet.apis import single_gpu_test
|
from mmdet.apis import single_gpu_test
|
||||||
outputs = single_gpu_test(model, data_loader, show, out_dir,
|
outputs = single_gpu_test(model, data_loader, show, out_dir,
|
||||||
show_score_thr)
|
show_score_thr)
|
||||||
|
elif codebase == 'mmseg':
|
||||||
|
from mmseg.apis import single_gpu_test
|
||||||
|
outputs = single_gpu_test(model, data_loader, show, out_dir)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||||
return outputs
|
return outputs
|
||||||
@ -138,5 +149,15 @@ def post_process_outputs(outputs,
|
|||||||
eval_kwargs.update(dict(metric=metrics, **kwargs))
|
eval_kwargs.update(dict(metric=metrics, **kwargs))
|
||||||
print(dataset.evaluate(outputs, **eval_kwargs))
|
print(dataset.evaluate(outputs, **eval_kwargs))
|
||||||
|
|
||||||
|
elif codebase == 'mmseg':
|
||||||
|
if out:
|
||||||
|
print(f'\nwriting results to {out}')
|
||||||
|
mmcv.dump(outputs, out)
|
||||||
|
kwargs = {} if metric_options is None else metric_options
|
||||||
|
if format_only:
|
||||||
|
dataset.format_results(outputs, **kwargs)
|
||||||
|
if metrics:
|
||||||
|
dataset.evaluate(outputs, metrics, **kwargs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||||
|
@ -31,49 +31,50 @@ def init_model(codebase: str,
|
|||||||
model_checkpoint: Optional[str] = None,
|
model_checkpoint: Optional[str] = None,
|
||||||
device: str = 'cuda:0',
|
device: str = 'cuda:0',
|
||||||
cfg_options: Optional[Dict] = None):
|
cfg_options: Optional[Dict] = None):
|
||||||
# mmcls
|
|
||||||
if codebase == 'mmcls':
|
|
||||||
assert_module_exist(codebase)
|
assert_module_exist(codebase)
|
||||||
|
if codebase == 'mmcls':
|
||||||
from mmcls.apis import init_model
|
from mmcls.apis import init_model
|
||||||
model = init_model(model_cfg, model_checkpoint, device, cfg_options)
|
model = init_model(model_cfg, model_checkpoint, device, cfg_options)
|
||||||
|
|
||||||
elif codebase == 'mmdet':
|
elif codebase == 'mmdet':
|
||||||
assert_module_exist(codebase)
|
|
||||||
from mmdet.apis import init_detector
|
from mmdet.apis import init_detector
|
||||||
model = init_detector(model_cfg, model_checkpoint, device, cfg_options)
|
model = init_detector(model_cfg, model_checkpoint, device, cfg_options)
|
||||||
|
|
||||||
elif codebase == 'mmseg':
|
elif codebase == 'mmseg':
|
||||||
assert_module_exist(codebase)
|
|
||||||
from mmseg.apis import init_segmentor
|
from mmseg.apis import init_segmentor
|
||||||
|
from mmdeploy.mmseg.export import convert_syncbatchnorm
|
||||||
model = init_segmentor(model_cfg, model_checkpoint, device)
|
model = init_segmentor(model_cfg, model_checkpoint, device)
|
||||||
|
model = convert_syncbatchnorm(model)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||||
|
|
||||||
return model
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
def create_input(codebase: str,
|
def create_input(codebase: str,
|
||||||
model_cfg: Union[str, mmcv.Config],
|
model_cfg: Union[str, mmcv.Config],
|
||||||
imgs: Any,
|
imgs: Any,
|
||||||
device: str = 'cuda:0'):
|
device: str = 'cuda:0'):
|
||||||
|
assert_module_exist(codebase)
|
||||||
if isinstance(model_cfg, str):
|
if isinstance(model_cfg, str):
|
||||||
model_cfg = mmcv.Config.fromfile(model_cfg)
|
model_cfg = mmcv.Config.fromfile(model_cfg)
|
||||||
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
|
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
|
||||||
raise TypeError('config must be a filename or Config object, '
|
raise TypeError('config must be a filename or Config object, '
|
||||||
f'but got {type(model_cfg)}')
|
f'but got {type(model_cfg)}')
|
||||||
|
|
||||||
cfg = model_cfg.copy()
|
cfg = model_cfg.copy()
|
||||||
if codebase == 'mmcls':
|
if codebase == 'mmcls':
|
||||||
assert_module_exist(codebase)
|
|
||||||
from mmdeploy.mmcls.export import create_input
|
from mmdeploy.mmcls.export import create_input
|
||||||
return create_input(cfg, imgs, device)
|
return create_input(cfg, imgs, device)
|
||||||
|
|
||||||
elif codebase == 'mmdet':
|
elif codebase == 'mmdet':
|
||||||
assert_module_exist(codebase)
|
|
||||||
from mmdeploy.mmdet.export import create_input
|
from mmdeploy.mmdet.export import create_input
|
||||||
return create_input(cfg, imgs, device)
|
return create_input(cfg, imgs, device)
|
||||||
|
|
||||||
|
elif codebase == 'mmseg':
|
||||||
|
from mmdeploy.mmseg.export import create_input
|
||||||
|
return create_input(cfg, imgs, device)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||||
|
|
||||||
@ -94,8 +95,8 @@ def init_backend_model(model_files: Sequence[str],
|
|||||||
backend: str,
|
backend: str,
|
||||||
class_names: Sequence[str],
|
class_names: Sequence[str],
|
||||||
device_id: int = 0):
|
device_id: int = 0):
|
||||||
if codebase == 'mmcls':
|
|
||||||
assert_module_exist(codebase)
|
assert_module_exist(codebase)
|
||||||
|
if codebase == 'mmcls':
|
||||||
if backend == 'onnxruntime':
|
if backend == 'onnxruntime':
|
||||||
from mmdeploy.mmcls.export import ONNXRuntimeClassifier
|
from mmdeploy.mmcls.export import ONNXRuntimeClassifier
|
||||||
backend_model = ONNXRuntimeClassifier(
|
backend_model = ONNXRuntimeClassifier(
|
||||||
@ -120,7 +121,6 @@ def init_backend_model(model_files: Sequence[str],
|
|||||||
return backend_model
|
return backend_model
|
||||||
|
|
||||||
elif codebase == 'mmdet':
|
elif codebase == 'mmdet':
|
||||||
assert_module_exist(codebase)
|
|
||||||
if backend == 'onnxruntime':
|
if backend == 'onnxruntime':
|
||||||
from mmdeploy.mmdet.export import ONNXRuntimeDetector
|
from mmdeploy.mmdet.export import ONNXRuntimeDetector
|
||||||
backend_model = ONNXRuntimeDetector(
|
backend_model = ONNXRuntimeDetector(
|
||||||
@ -137,61 +137,55 @@ def init_backend_model(model_files: Sequence[str],
|
|||||||
raise NotImplementedError(f'Unsupported backend type: {backend}')
|
raise NotImplementedError(f'Unsupported backend type: {backend}')
|
||||||
return backend_model
|
return backend_model
|
||||||
|
|
||||||
|
elif codebase == 'mmseg':
|
||||||
|
if backend == 'onnxruntime':
|
||||||
|
from mmdeploy.mmseg.export import ONNXRuntimeSegmentor
|
||||||
|
backend_model = ONNXRuntimeSegmentor(
|
||||||
|
model_files[0], class_names=class_names, device_id=device_id)
|
||||||
|
elif backend == 'tensorrt':
|
||||||
|
from mmdeploy.mmseg.export import TensorRTSegmentor
|
||||||
|
backend_model = TensorRTSegmentor(
|
||||||
|
model_files[0], class_names=class_names, device_id=device_id)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f'Unsupported backend type: {backend}')
|
||||||
|
return backend_model
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||||
|
|
||||||
|
|
||||||
def get_classes_from_config(codebase: str, model_cfg: Union[str, mmcv.Config]):
|
def get_classes_from_config(codebase: str, model_cfg: Union[str, mmcv.Config]):
|
||||||
|
assert_module_exist(codebase)
|
||||||
|
|
||||||
model_cfg_str = model_cfg
|
model_cfg_str = model_cfg
|
||||||
|
if isinstance(model_cfg, str):
|
||||||
|
model_cfg = mmcv.Config.fromfile(model_cfg)
|
||||||
|
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
|
||||||
|
raise TypeError('config must be a filename or Config object, '
|
||||||
|
f'but got {type(model_cfg)}')
|
||||||
|
|
||||||
if codebase == 'mmcls':
|
if codebase == 'mmcls':
|
||||||
assert_module_exist(codebase)
|
|
||||||
if isinstance(model_cfg, str):
|
|
||||||
model_cfg = mmcv.Config.fromfile(model_cfg)
|
|
||||||
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
|
|
||||||
raise TypeError('config must be a filename or Config object, '
|
|
||||||
f'but got {type(model_cfg)}')
|
|
||||||
|
|
||||||
from mmcls.datasets import DATASETS
|
from mmcls.datasets import DATASETS
|
||||||
module_dict = DATASETS.module_dict
|
elif codebase == 'mmdet':
|
||||||
data_cfg = model_cfg.data
|
|
||||||
|
|
||||||
if 'train' in data_cfg:
|
|
||||||
module = module_dict[data_cfg.train.type]
|
|
||||||
elif 'val' in data_cfg:
|
|
||||||
module = module_dict[data_cfg.val.type]
|
|
||||||
elif 'test' in data_cfg:
|
|
||||||
module = module_dict[data_cfg.test.type]
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f'No dataset config found in: {model_cfg_str}')
|
|
||||||
|
|
||||||
return module.CLASSES
|
|
||||||
|
|
||||||
if codebase == 'mmdet':
|
|
||||||
assert_module_exist(codebase)
|
|
||||||
if isinstance(model_cfg, str):
|
|
||||||
model_cfg = mmcv.Config.fromfile(model_cfg)
|
|
||||||
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
|
|
||||||
raise TypeError('config must be a filename or Config object, '
|
|
||||||
f'but got {type(model_cfg)}')
|
|
||||||
|
|
||||||
from mmdet.datasets import DATASETS
|
from mmdet.datasets import DATASETS
|
||||||
module_dict = DATASETS.module_dict
|
elif codebase == 'mmseg':
|
||||||
data_cfg = model_cfg.data
|
from mmseg.datasets import DATASETS
|
||||||
|
|
||||||
if 'train' in data_cfg:
|
|
||||||
module = module_dict[data_cfg.train.type]
|
|
||||||
elif 'val' in data_cfg:
|
|
||||||
module = module_dict[data_cfg.val.type]
|
|
||||||
elif 'test' in data_cfg:
|
|
||||||
module = module_dict[data_cfg.test.type]
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f'No dataset config found in: {model_cfg_str}')
|
|
||||||
|
|
||||||
return module.CLASSES
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||||
|
|
||||||
|
module_dict = DATASETS.module_dict
|
||||||
|
data_cfg = model_cfg.data
|
||||||
|
|
||||||
|
if 'train' in data_cfg:
|
||||||
|
module = module_dict[data_cfg.train.type]
|
||||||
|
elif 'val' in data_cfg:
|
||||||
|
module = module_dict[data_cfg.val.type]
|
||||||
|
elif 'test' in data_cfg:
|
||||||
|
module = module_dict[data_cfg.test.type]
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f'No dataset config found in: {model_cfg_str}')
|
||||||
|
|
||||||
|
return module.CLASSES
|
||||||
|
|
||||||
|
|
||||||
def check_model_outputs(codebase: str,
|
def check_model_outputs(codebase: str,
|
||||||
image: Union[str, np.ndarray],
|
image: Union[str, np.ndarray],
|
||||||
@ -199,10 +193,12 @@ def check_model_outputs(codebase: str,
|
|||||||
model,
|
model,
|
||||||
output_file: str,
|
output_file: str,
|
||||||
backend: str,
|
backend: str,
|
||||||
|
dataset: str = None,
|
||||||
show_result=False):
|
show_result=False):
|
||||||
show_img = mmcv.imread(image) if isinstance(image, str) else image
|
|
||||||
if codebase == 'mmcls':
|
|
||||||
assert_module_exist(codebase)
|
assert_module_exist(codebase)
|
||||||
|
show_img = mmcv.imread(image) if isinstance(image, str) else image
|
||||||
|
|
||||||
|
if codebase == 'mmcls':
|
||||||
output_file = None if show_result else output_file
|
output_file = None if show_result else output_file
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
scores = model(**model_inputs, return_loss=False)[0]
|
scores = model(**model_inputs, return_loss=False)[0]
|
||||||
@ -221,7 +217,6 @@ def check_model_outputs(codebase: str,
|
|||||||
out_file=output_file)
|
out_file=output_file)
|
||||||
|
|
||||||
elif codebase == 'mmdet':
|
elif codebase == 'mmdet':
|
||||||
assert_module_exist(codebase)
|
|
||||||
output_file = None if show_result else output_file
|
output_file = None if show_result else output_file
|
||||||
score_thr = 0.3
|
score_thr = 0.3
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -234,5 +229,20 @@ def check_model_outputs(codebase: str,
|
|||||||
win_name=backend,
|
win_name=backend,
|
||||||
out_file=output_file)
|
out_file=output_file)
|
||||||
|
|
||||||
|
elif codebase == 'mmseg':
|
||||||
|
output_file = None if show_result else output_file
|
||||||
|
from mmseg.core.evaluation import get_palette
|
||||||
|
dataset = 'cityscapes' if dataset is None else dataset
|
||||||
|
palette = get_palette(dataset)
|
||||||
|
with torch.no_grad():
|
||||||
|
results = model(**model_inputs, return_loss=False, rescale=True)
|
||||||
|
model.show_result(
|
||||||
|
show_img,
|
||||||
|
results,
|
||||||
|
palette=palette,
|
||||||
|
show=True,
|
||||||
|
win_name=backend,
|
||||||
|
out_file=output_file,
|
||||||
|
opacity=0.5)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
raise NotImplementedError(f'Unknown codebase type: {codebase}')
|
||||||
|
@ -74,7 +74,8 @@ class DeployBaseDetector(BaseDetector):
|
|||||||
img_h, img_w = img_metas[i]['img_shape'][:2]
|
img_h, img_w = img_metas[i]['img_shape'][:2]
|
||||||
ori_h, ori_w = img_metas[i]['ori_shape'][:2]
|
ori_h, ori_w = img_metas[i]['ori_shape'][:2]
|
||||||
masks = masks[:, :img_h, :img_w]
|
masks = masks[:, :img_h, :img_w]
|
||||||
if rescale and batch_masks.shape[1] > 0:
|
# avoid to resize masks with zero dim
|
||||||
|
if rescale and masks.shape[0] != 0:
|
||||||
masks = masks.astype(np.float32)
|
masks = masks.astype(np.float32)
|
||||||
masks = torch.from_numpy(masks)
|
masks = torch.from_numpy(masks)
|
||||||
masks = torch.nn.functional.interpolate(
|
masks = torch.nn.functional.interpolate(
|
||||||
|
@ -0,0 +1,2 @@
|
|||||||
|
from .export import * # noqa: F401,F403
|
||||||
|
from .models import * # noqa: F401,F403
|
@ -0,0 +1,8 @@
|
|||||||
|
from .model_wrappers import ONNXRuntimeSegmentor, TensorRTSegmentor
|
||||||
|
from .onnx_helper import convert_syncbatchnorm
|
||||||
|
from .prepare_input import create_input
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'create_input', 'ONNXRuntimeSegmentor', 'TensorRTSegmentor',
|
||||||
|
'convert_syncbatchnorm'
|
||||||
|
]
|
117
mmdeploy/mmseg/export/model_wrappers.py
Normal file
117
mmdeploy/mmseg/export/model_wrappers.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
import os.path as osp
|
||||||
|
import warnings
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from mmseg.models.segmentors.base import BaseSegmentor
|
||||||
|
from mmseg.ops import resize
|
||||||
|
|
||||||
|
|
||||||
|
class DeployBaseSegmentor(BaseSegmentor):
|
||||||
|
|
||||||
|
def __init__(self, class_names: Sequence[str], device_id: int):
|
||||||
|
super(DeployBaseSegmentor, self).__init__(init_cfg=None)
|
||||||
|
self.CLASSES = class_names
|
||||||
|
self.device_id = device_id
|
||||||
|
self.PALETTE = None
|
||||||
|
|
||||||
|
def extract_feat(self, imgs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def encode_decode(self, img, img_metas):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def forward_train(self, imgs, img_metas, **kwargs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def simple_test(self, img, img_meta, **kwargs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def aug_test(self, imgs, img_metas, **kwargs):
|
||||||
|
raise NotImplementedError('This method is not implemented.')
|
||||||
|
|
||||||
|
def forward(self, img, img_metas, **kwargs):
|
||||||
|
seg_pred = self.forward_test(img, img_metas, **kwargs)
|
||||||
|
# whole mode supports dynamic shape
|
||||||
|
ori_shape = img_metas[0][0]['ori_shape']
|
||||||
|
if not (ori_shape[0] == seg_pred.shape[-2]
|
||||||
|
and ori_shape[1] == seg_pred.shape[-1]):
|
||||||
|
seg_pred = torch.from_numpy(seg_pred).float()
|
||||||
|
seg_pred = resize(
|
||||||
|
seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
|
||||||
|
seg_pred = seg_pred.long().detach().cpu().numpy()
|
||||||
|
# remove unnecessary dim
|
||||||
|
seg_pred = seg_pred.squeeze(1)
|
||||||
|
seg_pred = list(seg_pred)
|
||||||
|
return seg_pred
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXRuntimeSegmentor(DeployBaseSegmentor):
|
||||||
|
|
||||||
|
def __init__(self, onnx_file: str, class_names: Sequence[str],
|
||||||
|
device_id: int):
|
||||||
|
super(ONNXRuntimeSegmentor, self).__init__(class_names, device_id)
|
||||||
|
|
||||||
|
import onnxruntime as ort
|
||||||
|
from mmdeploy.apis.onnxruntime import get_ops_path
|
||||||
|
|
||||||
|
# get the custom op path
|
||||||
|
ort_custom_op_path = get_ops_path()
|
||||||
|
session_options = ort.SessionOptions()
|
||||||
|
# register custom op for onnxruntime
|
||||||
|
if osp.exists(ort_custom_op_path):
|
||||||
|
session_options.register_custom_ops_library(ort_custom_op_path)
|
||||||
|
sess = ort.InferenceSession(onnx_file, session_options)
|
||||||
|
providers = ['CPUExecutionProvider']
|
||||||
|
options = [{}]
|
||||||
|
is_cuda_available = ort.get_device() == 'GPU'
|
||||||
|
if is_cuda_available:
|
||||||
|
providers.insert(0, 'CUDAExecutionProvider')
|
||||||
|
options.insert(0, {'device_id': device_id})
|
||||||
|
|
||||||
|
sess.set_providers(providers, options)
|
||||||
|
|
||||||
|
self.sess = sess
|
||||||
|
self.io_binding = sess.io_binding()
|
||||||
|
self.output_names = [_.name for _ in sess.get_outputs()]
|
||||||
|
for name in self.output_names:
|
||||||
|
self.io_binding.bind_output(name)
|
||||||
|
|
||||||
|
def forward_test(self, imgs, img_metas, **kwargs):
|
||||||
|
input_data = imgs[0]
|
||||||
|
device_type = input_data.device.type
|
||||||
|
self.io_binding.bind_input(
|
||||||
|
name='input',
|
||||||
|
device_type=device_type,
|
||||||
|
device_id=self.device_id,
|
||||||
|
element_type=np.float32,
|
||||||
|
shape=input_data.shape,
|
||||||
|
buffer_ptr=input_data.data_ptr())
|
||||||
|
self.sess.run_with_iobinding(self.io_binding)
|
||||||
|
seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
|
||||||
|
return seg_pred
|
||||||
|
|
||||||
|
|
||||||
|
class TensorRTSegmentor(DeployBaseSegmentor):
|
||||||
|
|
||||||
|
def __init__(self, trt_file: str, class_names: Sequence[str],
|
||||||
|
device_id: int):
|
||||||
|
super(TensorRTSegmentor, self).__init__(class_names, device_id)
|
||||||
|
|
||||||
|
from mmdeploy.apis.tensorrt import TRTWrapper, load_tensorrt_plugin
|
||||||
|
try:
|
||||||
|
load_tensorrt_plugin()
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
warnings.warn('If input model has custom plugins, \
|
||||||
|
you may have to build backend ops with TensorRT')
|
||||||
|
model = TRTWrapper(trt_file)
|
||||||
|
self.model = model
|
||||||
|
self.output_name = self.model.output_names[0]
|
||||||
|
|
||||||
|
def forward_test(self, imgs, img_metas, **kwargs):
|
||||||
|
input_data = imgs[0].contiguous()
|
||||||
|
with torch.cuda.device(self.device_id), torch.no_grad():
|
||||||
|
seg_pred = self.model({'input': input_data})[self.output_name]
|
||||||
|
seg_pred = seg_pred.detach().cpu().numpy()
|
||||||
|
return seg_pred
|
22
mmdeploy/mmseg/export/onnx_helper.py
Normal file
22
mmdeploy/mmseg/export/onnx_helper.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def convert_syncbatchnorm(module):
|
||||||
|
module_output = module
|
||||||
|
if isinstance(module, torch.nn.SyncBatchNorm):
|
||||||
|
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
|
||||||
|
module.momentum, module.affine,
|
||||||
|
module.track_running_stats)
|
||||||
|
if module.affine:
|
||||||
|
module_output.weight.data = module.weight.data.clone().detach()
|
||||||
|
module_output.bias.data = module.bias.data.clone().detach()
|
||||||
|
# keep requires_grad unchanged
|
||||||
|
module_output.weight.requires_grad = module.weight.requires_grad
|
||||||
|
module_output.bias.requires_grad = module.bias.requires_grad
|
||||||
|
module_output.running_mean = module.running_mean
|
||||||
|
module_output.running_var = module.running_var
|
||||||
|
module_output.num_batches_tracked = module.num_batches_tracked
|
||||||
|
for name, child in module.named_children():
|
||||||
|
module_output.add_module(name, convert_syncbatchnorm(child))
|
||||||
|
del module
|
||||||
|
return module_output
|
47
mmdeploy/mmseg/export/prepare_input.py
Normal file
47
mmdeploy/mmseg/export/prepare_input.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
from mmcv.parallel import collate, scatter
|
||||||
|
from mmseg.apis.inference import LoadImage
|
||||||
|
from mmseg.datasets.pipelines import Compose
|
||||||
|
|
||||||
|
|
||||||
|
def create_input(model_cfg: Union[str, mmcv.Config],
|
||||||
|
imgs: Any,
|
||||||
|
device: str = 'cuda:0'):
|
||||||
|
if isinstance(model_cfg, str):
|
||||||
|
model_cfg = mmcv.Config.fromfile(model_cfg)
|
||||||
|
elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
|
||||||
|
raise TypeError('config must be a filename or Config object, '
|
||||||
|
f'but got {type(model_cfg)}')
|
||||||
|
cfg = model_cfg.copy()
|
||||||
|
|
||||||
|
if not isinstance(imgs, (list, tuple)):
|
||||||
|
imgs = [imgs]
|
||||||
|
|
||||||
|
if isinstance(imgs[0], np.ndarray):
|
||||||
|
cfg = cfg.copy()
|
||||||
|
# set loading pipeline type
|
||||||
|
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
|
||||||
|
|
||||||
|
cfg.data.test.pipeline[1]['transforms'][0]['keep_ratio'] = False
|
||||||
|
cfg.data.test.pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]
|
||||||
|
|
||||||
|
test_pipeline = Compose(cfg.data.test.pipeline)
|
||||||
|
datas = []
|
||||||
|
for img in imgs:
|
||||||
|
# prepare data
|
||||||
|
data = dict(img=img)
|
||||||
|
# build the data pipeline
|
||||||
|
data = test_pipeline(data)
|
||||||
|
datas.append(data)
|
||||||
|
|
||||||
|
data = collate(datas, samples_per_gpu=len(imgs))
|
||||||
|
|
||||||
|
data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
|
||||||
|
data['img'] = [img.data[0][None, :] for img in data['img']]
|
||||||
|
if device != 'cpu':
|
||||||
|
data = scatter(data, [device])[0]
|
||||||
|
|
||||||
|
return data, data['img']
|
2
mmdeploy/mmseg/models/__init__.py
Normal file
2
mmdeploy/mmseg/models/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from .decode_heads import * # noqa: F401,F403
|
||||||
|
from .segmentors import * # noqa: F401,F403
|
4
mmdeploy/mmseg/models/decode_heads/__init__.py
Normal file
4
mmdeploy/mmseg/models/decode_heads/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from .aspp_head import forward_of_aspp_head
|
||||||
|
from .psp_head import forward_of_ppm
|
||||||
|
|
||||||
|
__all__ = ['forward_of_aspp_head', 'forward_of_ppm']
|
30
mmdeploy/mmseg/models/decode_heads/aspp_head.py
Normal file
30
mmdeploy/mmseg/models/decode_heads/aspp_head.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import torch
|
||||||
|
from mmseg.ops import resize
|
||||||
|
|
||||||
|
from mmdeploy.core import FUNCTION_REWRITER
|
||||||
|
from mmdeploy.utils import is_dynamic_shape
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
|
func_name='mmseg.models.decode_heads.ASPPHead.forward')
|
||||||
|
def forward_of_aspp_head(ctx, self, inputs):
|
||||||
|
x = self._transform_inputs(inputs)
|
||||||
|
deploy_cfg = ctx.cfg
|
||||||
|
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||||
|
# get origin input shape as tensor to support onnx dynamic shape
|
||||||
|
size = x.shape[2:]
|
||||||
|
if not is_dynamic_flag:
|
||||||
|
size = [int(val) for val in size]
|
||||||
|
|
||||||
|
aspp_outs = [
|
||||||
|
resize(
|
||||||
|
self.image_pool(x),
|
||||||
|
size=size,
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=self.align_corners)
|
||||||
|
]
|
||||||
|
aspp_outs.extend(self.aspp_modules(x))
|
||||||
|
aspp_outs = torch.cat(aspp_outs, dim=1)
|
||||||
|
output = self.bottleneck(aspp_outs)
|
||||||
|
output = self.cls_seg(output)
|
||||||
|
return output
|
26
mmdeploy/mmseg/models/decode_heads/psp_head.py
Normal file
26
mmdeploy/mmseg/models/decode_heads/psp_head.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from mmseg.ops import resize
|
||||||
|
|
||||||
|
from mmdeploy.core import FUNCTION_REWRITER
|
||||||
|
from mmdeploy.utils import is_dynamic_shape
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
|
func_name='mmseg.models.decode_heads.psp_head.PPM.forward')
|
||||||
|
def forward_of_ppm(ctx, self, x):
|
||||||
|
deploy_cfg = ctx.cfg
|
||||||
|
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||||
|
# get origin input shape as tensor to support onnx dynamic shape
|
||||||
|
size = x.shape[2:]
|
||||||
|
if not is_dynamic_flag:
|
||||||
|
size = [int(val) for val in size]
|
||||||
|
|
||||||
|
ppm_outs = []
|
||||||
|
for ppm in self:
|
||||||
|
ppm_out = ppm(x)
|
||||||
|
upsampled_ppm_out = resize(
|
||||||
|
ppm_out,
|
||||||
|
size=size,
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=self.align_corners)
|
||||||
|
ppm_outs.append(upsampled_ppm_out)
|
||||||
|
return ppm_outs
|
4
mmdeploy/mmseg/models/segmentors/__init__.py
Normal file
4
mmdeploy/mmseg/models/segmentors/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
from .base import forward_of_base_segmentor
|
||||||
|
from .encoder_decoder import simple_test_of_encoder_decoder
|
||||||
|
|
||||||
|
__all__ = ['forward_of_base_segmentor', 'simple_test_of_encoder_decoder']
|
22
mmdeploy/mmseg/models/segmentors/base.py
Normal file
22
mmdeploy/mmseg/models/segmentors/base.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from mmdeploy.core import FUNCTION_REWRITER
|
||||||
|
from mmdeploy.utils import is_dynamic_shape
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
|
func_name='mmseg.models.segmentors.BaseSegmentor.forward')
|
||||||
|
def forward_of_base_segmentor(ctx, self, img, img_metas=None, **kwargs):
|
||||||
|
if img_metas is None:
|
||||||
|
img_metas = {}
|
||||||
|
assert isinstance(img_metas, dict)
|
||||||
|
assert isinstance(img, torch.Tensor)
|
||||||
|
|
||||||
|
deploy_cfg = ctx.cfg
|
||||||
|
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
|
||||||
|
# get origin input shape as tensor to support onnx dynamic shape
|
||||||
|
img_shape = img.shape[2:]
|
||||||
|
if not is_dynamic_flag:
|
||||||
|
img_shape = [int(val) for val in img_shape]
|
||||||
|
img_metas['img_shape'] = img_shape
|
||||||
|
return self.simple_test(img, img_metas, **kwargs)
|
21
mmdeploy/mmseg/models/segmentors/encoder_decoder.py
Normal file
21
mmdeploy/mmseg/models/segmentors/encoder_decoder.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import torch.nn.functional as F
|
||||||
|
from mmseg.ops import resize
|
||||||
|
|
||||||
|
from mmdeploy.core import FUNCTION_REWRITER
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITER.register_rewriter(
|
||||||
|
func_name='mmseg.models.segmentors.EncoderDecoder.simple_test')
|
||||||
|
def simple_test_of_encoder_decoder(ctx, self, img, img_meta, **kwargs):
|
||||||
|
x = self.extract_feat(img)
|
||||||
|
seg_logit = self._decode_head_forward_test(x, img_meta)
|
||||||
|
seg_logit = resize(
|
||||||
|
input=seg_logit,
|
||||||
|
size=img_meta['img_shape'],
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=self.align_corners)
|
||||||
|
seg_logit = F.softmax(seg_logit, dim=1)
|
||||||
|
seg_pred = seg_logit.argmax(dim=1)
|
||||||
|
# our inference backend only support 4D output
|
||||||
|
seg_pred = seg_pred.unsqueeze(1)
|
||||||
|
return seg_pred
|
Loading…
x
Reference in New Issue
Block a user