mirror of https://github.com/open-mmlab/mmdeploy.git
[Feature]: Support FCN, DeepLabV3, DeepLabV3Plus in mmseg with ONNXRuntime and TensorRT (#31)
* fix empty mask result
* support exporting FCN to ONNX for ONNXRuntime and TensorRT in whole mode
* resolve comments
* remove unnecessary code
* update prepare_input
* rewrite psp_head & aspp_head
* test FCN, DeepLabV3 and DeepLabV3Plus with TensorRT
parent dcb88e4439
commit 7dbc12d23f
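The commit touches both the codebase-dispatch helpers and new mmdeploy.mmseg modules. A rough sketch of the PyTorch-side helpers extended here for mmseg (the importing module path, config and checkpoint names are placeholders, and the ONNX/TensorRT conversion entry point itself is not part of this diff):

# Hedged sketch; paths and the module path of the helpers are assumptions.
from mmdeploy.apis.utils import init_model, create_input

model_cfg = 'fcn_r50-d8_512x1024_40k_cityscapes.py'   # placeholder mmseg config
torch_model = init_model('mmseg', model_cfg, 'fcn_r50-d8.pth', device='cuda:0')
data, img_list = create_input('mmseg', model_cfg, 'demo.png', device='cuda:0')
# img_list holds the preprocessed input tensor(s); data also carries img_metas
# for comparing backend outputs against the PyTorch segmentor.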
@@ -1,2 +1,2 @@
[settings]
known_third_party = mmcls,mmcv,mmdet,numpy,onnx,packaging,pyppl,pytest,setuptools,tensorrt,torch
known_third_party = mmcls,mmcv,mmdet,mmseg,numpy,onnx,packaging,pyppl,pytest,setuptools,tensorrt,torch
configs/mmseg/base.py (new file, 18 lines)
@@ -0,0 +1,18 @@
_base_ = ['../_base_/torch2onnx.py']
codebase = 'mmseg'
pytorch2onnx = dict(
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {
            0: 'batch',
            2: 'height',
            3: 'width'
        },
        'output': {
            0: 'batch',
            2: 'height',
            3: 'width'
        },
    },
)
configs/mmseg/onnxruntime.py (new file, 1 line)
@@ -0,0 +1 @@
_base_ = ['./base.py', '../_base_/backends/onnxruntime.py']
configs/mmseg/tensorrt.py (new file, 7 lines)
@@ -0,0 +1,7 @@
_base_ = ['./base.py', '../_base_/backends/tensorrt.py']
tensorrt_params = dict(model_params=[
    dict(
        opt_shape_dict=dict(
            input=[[1, 3, 512, 512], [1, 3, 1024, 2048], [1, 3, 2048, 2048]]),
        max_workspace_size=1 << 30)
])
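The three shapes under opt_shape_dict read as the minimum, optimum and maximum input sizes of a TensorRT optimization profile (an interpretation of this config format, not stated in the diff). A deployment with one fixed resolution could pin all three, e.g.:

# Hypothetical override for a fixed 1024x1024 input; the min/opt/max reading
# of the three entries is an assumption about this config format.
tensorrt_params = dict(model_params=[
    dict(
        opt_shape_dict=dict(
            input=[[1, 3, 1024, 1024], [1, 3, 1024, 1024], [1, 3, 1024, 1024]]),
        max_workspace_size=1 << 30)
])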
@@ -10,13 +10,13 @@ from mmdeploy.apis.utils import assert_module_exist


def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]):
    assert_module_exist(codebase)
    # load model_cfg if necessary
    if isinstance(model_cfg, str):
        model_cfg = mmcv.Config.fromfile(model_cfg)

    if codebase == 'mmcls':
        from mmcls.datasets import (build_dataloader, build_dataset)
        assert_module_exist(codebase)
        # build dataset and dataloader
        dataset = build_dataset(model_cfg.data.test)
        data_loader = build_dataloader(
@@ -27,7 +27,6 @@ def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]):
            round_up=False)

    elif codebase == 'mmdet':
        assert_module_exist(codebase)
        from mmdet.datasets import (build_dataloader, build_dataset,
                                    replace_ImageToTensor)
        # in case the test dataset is concatenated
@@ -58,7 +57,17 @@ def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]):
            workers_per_gpu=model_cfg.data.workers_per_gpu,
            dist=False,
            shuffle=False)

    elif codebase == 'mmseg':
        from mmseg.datasets import build_dataset, build_dataloader
        model_cfg.data.test.test_mode = True
        dataset = build_dataset(model_cfg.data.test)
        samples_per_gpu = 1
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=samples_per_gpu,
            workers_per_gpu=model_cfg.data.workers_per_gpu,
            dist=False,
            shuffle=False)
    else:
        raise NotImplementedError(f'Unknown codebase type: {codebase}')

@@ -71,16 +80,18 @@ def single_gpu_test(codebase: str,
                    show: bool = False,
                    out_dir: Any = None,
                    show_score_thr: float = 0.3):
    assert_module_exist(codebase)

    if codebase == 'mmcls':
        assert_module_exist(codebase)
        from mmcls.apis import single_gpu_test
        outputs = single_gpu_test(model, data_loader, show, out_dir)
    elif codebase == 'mmdet':
        assert_module_exist(codebase)
        from mmdet.apis import single_gpu_test
        outputs = single_gpu_test(model, data_loader, show, out_dir,
                                  show_score_thr)

    elif codebase == 'mmseg':
        from mmseg.apis import single_gpu_test
        outputs = single_gpu_test(model, data_loader, show, out_dir)
    else:
        raise NotImplementedError(f'Unknown codebase type: {codebase}')
    return outputs
@@ -138,5 +149,15 @@ def post_process_outputs(outputs,
        eval_kwargs.update(dict(metric=metrics, **kwargs))
        print(dataset.evaluate(outputs, **eval_kwargs))

    elif codebase == 'mmseg':
        if out:
            print(f'\nwriting results to {out}')
            mmcv.dump(outputs, out)
        kwargs = {} if metric_options is None else metric_options
        if format_only:
            dataset.format_results(outputs, **kwargs)
        if metrics:
            dataset.evaluate(outputs, metrics, **kwargs)

    else:
        raise NotImplementedError(f'Unknown codebase type: {codebase}')
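A sketch of how the new mmseg branches of these helpers chain together for evaluation (the module path, the leading (model, data_loader) parameters of single_gpu_test and the keyword names of init_backend_model are inferred rather than fully shown in this diff):

# Hedged sketch; 'end2end.onnx' and the config path are placeholders.
from mmdeploy.apis.utils import (get_classes_from_config, init_backend_model,
                                 prepare_data_loader, single_gpu_test)

model_cfg = 'fcn_r50-d8_512x1024_40k_cityscapes.py'
backend_model = init_backend_model(
    ['end2end.onnx'], codebase='mmseg', backend='onnxruntime',
    class_names=get_classes_from_config('mmseg', model_cfg))
data_loader = prepare_data_loader('mmseg', model_cfg)
outputs = single_gpu_test('mmseg', backend_model, data_loader, show=False)
# post_process_outputs(...) can then dump the outputs and call
# dataset.evaluate(), e.g. with mmseg's mIoU metric.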
@@ -31,49 +31,50 @@ def init_model(codebase: str,
               model_checkpoint: Optional[str] = None,
               device: str = 'cuda:0',
               cfg_options: Optional[Dict] = None):
    # mmcls
    assert_module_exist(codebase)
    if codebase == 'mmcls':
        assert_module_exist(codebase)
        from mmcls.apis import init_model
        model = init_model(model_cfg, model_checkpoint, device, cfg_options)

    elif codebase == 'mmdet':
        assert_module_exist(codebase)
        from mmdet.apis import init_detector
        model = init_detector(model_cfg, model_checkpoint, device, cfg_options)

    elif codebase == 'mmseg':
        assert_module_exist(codebase)
        from mmseg.apis import init_segmentor
        from mmdeploy.mmseg.export import convert_syncbatchnorm
        model = init_segmentor(model_cfg, model_checkpoint, device)
        model = convert_syncbatchnorm(model)

    else:
        raise NotImplementedError(f'Unknown codebase type: {codebase}')

    return model
    return model.eval()


def create_input(codebase: str,
                 model_cfg: Union[str, mmcv.Config],
                 imgs: Any,
                 device: str = 'cuda:0'):
    assert_module_exist(codebase)
    if isinstance(model_cfg, str):
        model_cfg = mmcv.Config.fromfile(model_cfg)
    elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(model_cfg)}')

    cfg = model_cfg.copy()
    if codebase == 'mmcls':
        assert_module_exist(codebase)
        from mmdeploy.mmcls.export import create_input
        return create_input(cfg, imgs, device)

    elif codebase == 'mmdet':
        assert_module_exist(codebase)
        from mmdeploy.mmdet.export import create_input
        return create_input(cfg, imgs, device)

    elif codebase == 'mmseg':
        from mmdeploy.mmseg.export import create_input
        return create_input(cfg, imgs, device)

    else:
        raise NotImplementedError(f'Unknown codebase type: {codebase}')
@@ -94,8 +95,8 @@ def init_backend_model(model_files: Sequence[str],
                       backend: str,
                       class_names: Sequence[str],
                       device_id: int = 0):
    assert_module_exist(codebase)
    if codebase == 'mmcls':
        assert_module_exist(codebase)
        if backend == 'onnxruntime':
            from mmdeploy.mmcls.export import ONNXRuntimeClassifier
            backend_model = ONNXRuntimeClassifier(
@@ -120,7 +121,6 @@ def init_backend_model(model_files: Sequence[str],
        return backend_model

    elif codebase == 'mmdet':
        assert_module_exist(codebase)
        if backend == 'onnxruntime':
            from mmdeploy.mmdet.export import ONNXRuntimeDetector
            backend_model = ONNXRuntimeDetector(
@@ -137,61 +137,55 @@ def init_backend_model(model_files: Sequence[str],
            raise NotImplementedError(f'Unsupported backend type: {backend}')
        return backend_model

    elif codebase == 'mmseg':
        if backend == 'onnxruntime':
            from mmdeploy.mmseg.export import ONNXRuntimeSegmentor
            backend_model = ONNXRuntimeSegmentor(
                model_files[0], class_names=class_names, device_id=device_id)
        elif backend == 'tensorrt':
            from mmdeploy.mmseg.export import TensorRTSegmentor
            backend_model = TensorRTSegmentor(
                model_files[0], class_names=class_names, device_id=device_id)
        else:
            raise NotImplementedError(f'Unsupported backend type: {backend}')
        return backend_model
    else:
        raise NotImplementedError(f'Unknown codebase type: {codebase}')

def get_classes_from_config(codebase: str, model_cfg: Union[str, mmcv.Config]):
    assert_module_exist(codebase)

    model_cfg_str = model_cfg
    if isinstance(model_cfg, str):
        model_cfg = mmcv.Config.fromfile(model_cfg)
    elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(model_cfg)}')

    if codebase == 'mmcls':
        assert_module_exist(codebase)
        if isinstance(model_cfg, str):
            model_cfg = mmcv.Config.fromfile(model_cfg)
        elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
            raise TypeError('config must be a filename or Config object, '
                            f'but got {type(model_cfg)}')

        from mmcls.datasets import DATASETS
        module_dict = DATASETS.module_dict
        data_cfg = model_cfg.data

        if 'train' in data_cfg:
            module = module_dict[data_cfg.train.type]
        elif 'val' in data_cfg:
            module = module_dict[data_cfg.val.type]
        elif 'test' in data_cfg:
            module = module_dict[data_cfg.test.type]
        else:
            raise RuntimeError(f'No dataset config found in: {model_cfg_str}')

        return module.CLASSES

    if codebase == 'mmdet':
        assert_module_exist(codebase)
        if isinstance(model_cfg, str):
            model_cfg = mmcv.Config.fromfile(model_cfg)
        elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
            raise TypeError('config must be a filename or Config object, '
                            f'but got {type(model_cfg)}')

    elif codebase == 'mmdet':
        from mmdet.datasets import DATASETS
        module_dict = DATASETS.module_dict
        data_cfg = model_cfg.data

        if 'train' in data_cfg:
            module = module_dict[data_cfg.train.type]
        elif 'val' in data_cfg:
            module = module_dict[data_cfg.val.type]
        elif 'test' in data_cfg:
            module = module_dict[data_cfg.test.type]
        else:
            raise RuntimeError(f'No dataset config found in: {model_cfg_str}')

        return module.CLASSES

    elif codebase == 'mmseg':
        from mmseg.datasets import DATASETS
    else:
        raise NotImplementedError(f'Unknown codebase type: {codebase}')

    module_dict = DATASETS.module_dict
    data_cfg = model_cfg.data

    if 'train' in data_cfg:
        module = module_dict[data_cfg.train.type]
    elif 'val' in data_cfg:
        module = module_dict[data_cfg.val.type]
    elif 'test' in data_cfg:
        module = module_dict[data_cfg.test.type]
    else:
        raise RuntimeError(f'No dataset config found in: {model_cfg_str}')

    return module.CLASSES


def check_model_outputs(codebase: str,
                        image: Union[str, np.ndarray],
@@ -199,10 +193,12 @@ def check_model_outputs(codebase: str,
                        model,
                        output_file: str,
                        backend: str,
                        dataset: str = None,
                        show_result=False):
    assert_module_exist(codebase)
    show_img = mmcv.imread(image) if isinstance(image, str) else image

    if codebase == 'mmcls':
        assert_module_exist(codebase)
        output_file = None if show_result else output_file
        with torch.no_grad():
            scores = model(**model_inputs, return_loss=False)[0]
@@ -221,7 +217,6 @@ def check_model_outputs(codebase: str,
            out_file=output_file)

    elif codebase == 'mmdet':
        assert_module_exist(codebase)
        output_file = None if show_result else output_file
        score_thr = 0.3
        with torch.no_grad():
@@ -234,5 +229,20 @@ def check_model_outputs(codebase: str,
            win_name=backend,
            out_file=output_file)

    elif codebase == 'mmseg':
        output_file = None if show_result else output_file
        from mmseg.core.evaluation import get_palette
        dataset = 'cityscapes' if dataset is None else dataset
        palette = get_palette(dataset)
        with torch.no_grad():
            results = model(**model_inputs, return_loss=False, rescale=True)
            model.show_result(
                show_img,
                results,
                palette=palette,
                show=True,
                win_name=backend,
                out_file=output_file,
                opacity=0.5)
    else:
        raise NotImplementedError(f'Unknown codebase type: {codebase}')
@@ -74,7 +74,8 @@ class DeployBaseDetector(BaseDetector):
            img_h, img_w = img_metas[i]['img_shape'][:2]
            ori_h, ori_w = img_metas[i]['ori_shape'][:2]
            masks = masks[:, :img_h, :img_w]
            if rescale and batch_masks.shape[1] > 0:
            # avoid to resize masks with zero dim
            if rescale and masks.shape[0] != 0:
                masks = masks.astype(np.float32)
                masks = torch.from_numpy(masks)
                masks = torch.nn.functional.interpolate(
@@ -0,0 +1,2 @@
from .export import *  # noqa: F401,F403
from .models import *  # noqa: F401,F403
@@ -0,0 +1,8 @@
from .model_wrappers import ONNXRuntimeSegmentor, TensorRTSegmentor
from .onnx_helper import convert_syncbatchnorm
from .prepare_input import create_input

__all__ = [
    'create_input', 'ONNXRuntimeSegmentor', 'TensorRTSegmentor',
    'convert_syncbatchnorm'
]
mmdeploy/mmseg/export/model_wrappers.py (new file, 117 lines)
@@ -0,0 +1,117 @@
import os.path as osp
import warnings
from typing import Sequence

import numpy as np
import torch
from mmseg.models.segmentors.base import BaseSegmentor
from mmseg.ops import resize


class DeployBaseSegmentor(BaseSegmentor):

    def __init__(self, class_names: Sequence[str], device_id: int):
        super(DeployBaseSegmentor, self).__init__(init_cfg=None)
        self.CLASSES = class_names
        self.device_id = device_id
        self.PALETTE = None

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def encode_decode(self, img, img_metas):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def simple_test(self, img, img_meta, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def forward(self, img, img_metas, **kwargs):
        seg_pred = self.forward_test(img, img_metas, **kwargs)
        # whole mode supports dynamic shape
        ori_shape = img_metas[0][0]['ori_shape']
        if not (ori_shape[0] == seg_pred.shape[-2]
                and ori_shape[1] == seg_pred.shape[-1]):
            seg_pred = torch.from_numpy(seg_pred).float()
            seg_pred = resize(
                seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
            seg_pred = seg_pred.long().detach().cpu().numpy()
        # remove unnecessary dim
        seg_pred = seg_pred.squeeze(1)
        seg_pred = list(seg_pred)
        return seg_pred


class ONNXRuntimeSegmentor(DeployBaseSegmentor):

    def __init__(self, onnx_file: str, class_names: Sequence[str],
                 device_id: int):
        super(ONNXRuntimeSegmentor, self).__init__(class_names, device_id)

        import onnxruntime as ort
        from mmdeploy.apis.onnxruntime import get_ops_path

        # get the custom op path
        ort_custom_op_path = get_ops_path()
        session_options = ort.SessionOptions()
        # register custom op for onnxruntime
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        providers = ['CPUExecutionProvider']
        options = [{}]
        is_cuda_available = ort.get_device() == 'GPU'
        if is_cuda_available:
            providers.insert(0, 'CUDAExecutionProvider')
            options.insert(0, {'device_id': device_id})

        sess.set_providers(providers, options)

        self.sess = sess
        self.io_binding = sess.io_binding()
        self.output_names = [_.name for _ in sess.get_outputs()]
        for name in self.output_names:
            self.io_binding.bind_output(name)

    def forward_test(self, imgs, img_metas, **kwargs):
        input_data = imgs[0]
        device_type = input_data.device.type
        self.io_binding.bind_input(
            name='input',
            device_type=device_type,
            device_id=self.device_id,
            element_type=np.float32,
            shape=input_data.shape,
            buffer_ptr=input_data.data_ptr())
        self.sess.run_with_iobinding(self.io_binding)
        seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
        return seg_pred


class TensorRTSegmentor(DeployBaseSegmentor):

    def __init__(self, trt_file: str, class_names: Sequence[str],
                 device_id: int):
        super(TensorRTSegmentor, self).__init__(class_names, device_id)

        from mmdeploy.apis.tensorrt import TRTWrapper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except (ImportError, ModuleNotFoundError):
            warnings.warn('If input model has custom plugins, \
                you may have to build backend ops with TensorRT')
        model = TRTWrapper(trt_file)
        self.model = model
        self.output_name = self.model.output_names[0]

    def forward_test(self, imgs, img_metas, **kwargs):
        input_data = imgs[0].contiguous()
        with torch.cuda.device(self.device_id), torch.no_grad():
            seg_pred = self.model({'input': input_data})[self.output_name]
        seg_pred = seg_pred.detach().cpu().numpy()
        return seg_pred
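Both wrappers subclass BaseSegmentor and only implement forward_test, so they drop into the existing mmseg test tooling. A hedged smoke-test sketch (the ONNX file, class names and metas layout are placeholders):

import torch
from mmdeploy.mmseg.export import ONNXRuntimeSegmentor

# 'fcn.onnx' is a placeholder for a model exported with the 'input'/'output'
# names declared in configs/mmseg/base.py.
segmentor = ONNXRuntimeSegmentor('fcn.onnx', class_names=['road', 'sidewalk'],
                                 device_id=0)
img = torch.rand(1, 3, 512, 1024)
img_metas = [[{'ori_shape': (512, 1024, 3)}]]    # nested like mmseg aug-test metas
seg = segmentor(img=[img], img_metas=img_metas)  # list of H x W label maps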
mmdeploy/mmseg/export/onnx_helper.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import torch


def convert_syncbatchnorm(module):
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, convert_syncbatchnorm(child))
    del module
    return module_output
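init_model already applies this conversion on the mmseg path before export; a minimal standalone sketch of the same idea:

import torch
from mmdeploy.mmseg.export import convert_syncbatchnorm

# SyncBatchNorm is generally not exportable to ONNX, so it is swapped for an
# equivalent BatchNorm2d that inherits the weights and running statistics.
net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.SyncBatchNorm(8))
net = convert_syncbatchnorm(net)
assert isinstance(net[1], torch.nn.BatchNorm2d)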
mmdeploy/mmseg/export/prepare_input.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from typing import Any, Union

import mmcv
import numpy as np
from mmcv.parallel import collate, scatter
from mmseg.apis.inference import LoadImage
from mmseg.datasets.pipelines import Compose


def create_input(model_cfg: Union[str, mmcv.Config],
                 imgs: Any,
                 device: str = 'cuda:0'):
    if isinstance(model_cfg, str):
        model_cfg = mmcv.Config.fromfile(model_cfg)
    elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)):
        raise TypeError('config must be a filename or Config object, '
                        f'but got {type(model_cfg)}')
    cfg = model_cfg.copy()

    if not isinstance(imgs, (list, tuple)):
        imgs = [imgs]

    if isinstance(imgs[0], np.ndarray):
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'

    cfg.data.test.pipeline[1]['transforms'][0]['keep_ratio'] = False
    cfg.data.test.pipeline = [LoadImage()] + cfg.data.test.pipeline[1:]

    test_pipeline = Compose(cfg.data.test.pipeline)
    datas = []
    for img in imgs:
        # prepare data
        data = dict(img=img)
        # build the data pipeline
        data = test_pipeline(data)
        datas.append(data)

    data = collate(datas, samples_per_gpu=len(imgs))

    data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
    data['img'] = [img.data[0][None, :] for img in data['img']]
    if device != 'cpu':
        data = scatter(data, [device])[0]

    return data, data['img']
mmdeploy/mmseg/models/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .decode_heads import *  # noqa: F401,F403
from .segmentors import *  # noqa: F401,F403
mmdeploy/mmseg/models/decode_heads/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .aspp_head import forward_of_aspp_head
from .psp_head import forward_of_ppm

__all__ = ['forward_of_aspp_head', 'forward_of_ppm']
mmdeploy/mmseg/models/decode_heads/aspp_head.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import torch
from mmseg.ops import resize

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmseg.models.decode_heads.ASPPHead.forward')
def forward_of_aspp_head(ctx, self, inputs):
    x = self._transform_inputs(inputs)
    deploy_cfg = ctx.cfg
    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
    # get origin input shape as tensor to support onnx dynamic shape
    size = x.shape[2:]
    if not is_dynamic_flag:
        size = [int(val) for val in size]

    aspp_outs = [
        resize(
            self.image_pool(x),
            size=size,
            mode='bilinear',
            align_corners=self.align_corners)
    ]
    aspp_outs.extend(self.aspp_modules(x))
    aspp_outs = torch.cat(aspp_outs, dim=1)
    output = self.bottleneck(aspp_outs)
    output = self.cls_seg(output)
    return output
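The int(val) cast above (repeated in psp_head.py and segmentors/base.py below) pins the resize target to constants for static-shape exports; when the deploy config declares dynamic axes, the traced shape values are kept instead. A standalone illustration (not mmdeploy code):

import torch

x = torch.rand(1, 3, 64, 64)
size = x.shape[2:]                    # traced shape values during ONNX export
static_size = [int(v) for v in size]  # plain ints, baked into the graph as constants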
mmdeploy/mmseg/models/decode_heads/psp_head.py (new file, 26 lines)
@@ -0,0 +1,26 @@
from mmseg.ops import resize

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmseg.models.decode_heads.psp_head.PPM.forward')
def forward_of_ppm(ctx, self, x):
    deploy_cfg = ctx.cfg
    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
    # get origin input shape as tensor to support onnx dynamic shape
    size = x.shape[2:]
    if not is_dynamic_flag:
        size = [int(val) for val in size]

    ppm_outs = []
    for ppm in self:
        ppm_out = ppm(x)
        upsampled_ppm_out = resize(
            ppm_out,
            size=size,
            mode='bilinear',
            align_corners=self.align_corners)
        ppm_outs.append(upsampled_ppm_out)
    return ppm_outs
mmdeploy/mmseg/models/segmentors/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .base import forward_of_base_segmentor
from .encoder_decoder import simple_test_of_encoder_decoder

__all__ = ['forward_of_base_segmentor', 'simple_test_of_encoder_decoder']
mmdeploy/mmseg/models/segmentors/base.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import torch

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmseg.models.segmentors.BaseSegmentor.forward')
def forward_of_base_segmentor(ctx, self, img, img_metas=None, **kwargs):
    if img_metas is None:
        img_metas = {}
    assert isinstance(img_metas, dict)
    assert isinstance(img, torch.Tensor)

    deploy_cfg = ctx.cfg
    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
    # get origin input shape as tensor to support onnx dynamic shape
    img_shape = img.shape[2:]
    if not is_dynamic_flag:
        img_shape = [int(val) for val in img_shape]
    img_metas['img_shape'] = img_shape
    return self.simple_test(img, img_metas, **kwargs)
mmdeploy/mmseg/models/segmentors/encoder_decoder.py (new file, 21 lines)
@@ -0,0 +1,21 @@
import torch.nn.functional as F
from mmseg.ops import resize

from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmseg.models.segmentors.EncoderDecoder.simple_test')
def simple_test_of_encoder_decoder(ctx, self, img, img_meta, **kwargs):
    x = self.extract_feat(img)
    seg_logit = self._decode_head_forward_test(x, img_meta)
    seg_logit = resize(
        input=seg_logit,
        size=img_meta['img_shape'],
        mode='bilinear',
        align_corners=self.align_corners)
    seg_logit = F.softmax(seg_logit, dim=1)
    seg_pred = seg_logit.argmax(dim=1)
    # our inference backend only support 4D output
    seg_pred = seg_pred.unsqueeze(1)
    return seg_pred
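Read together with configs/mmseg/base.py, the export keeps a 4-D layout end to end: the extra channel dim added here matches the 'output' dynamic axes, and DeployBaseSegmentor.forward squeezes it back out after inference. A small shape check (the class count is an arbitrary example):

import torch
import torch.nn.functional as F

seg_logit = torch.rand(1, 19, 512, 512)                # resized logits, 19 classes
seg_pred = F.softmax(seg_logit, dim=1).argmax(dim=1)   # [1, 512, 512]
seg_pred = seg_pred.unsqueeze(1)                       # [1, 1, 512, 512], 4-D output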