diff --git a/.isort.cfg b/.isort.cfg index acb76d8ad..61a8a593a 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,2 +1,2 @@ [settings] -known_third_party = mmcls,mmcv,mmdet,numpy,onnx,packaging,pyppl,pytest,setuptools,tensorrt,torch +known_third_party = mmcls,mmcv,mmdet,mmseg,numpy,onnx,packaging,pyppl,pytest,setuptools,tensorrt,torch diff --git a/configs/mmseg/base.py b/configs/mmseg/base.py new file mode 100644 index 000000000..563c73ddc --- /dev/null +++ b/configs/mmseg/base.py @@ -0,0 +1,18 @@ +_base_ = ['../_base_/torch2onnx.py'] +codebase = 'mmseg' +pytorch2onnx = dict( + input_names=['input'], + output_names=['output'], + dynamic_axes={ + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'output': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + }, +) diff --git a/configs/mmseg/onnxruntime.py b/configs/mmseg/onnxruntime.py new file mode 100644 index 000000000..83544d08b --- /dev/null +++ b/configs/mmseg/onnxruntime.py @@ -0,0 +1 @@ +_base_ = ['./base.py', '../_base_/backends/onnxruntime.py'] diff --git a/configs/mmseg/tensorrt.py b/configs/mmseg/tensorrt.py new file mode 100644 index 000000000..628293644 --- /dev/null +++ b/configs/mmseg/tensorrt.py @@ -0,0 +1,7 @@ +_base_ = ['./base.py', '../_base_/backends/tensorrt.py'] +tensorrt_params = dict(model_params=[ + dict( + opt_shape_dict=dict( + input=[[1, 3, 512, 512], [1, 3, 1024, 2048], [1, 3, 2048, 2048]]), + max_workspace_size=1 << 30) +]) diff --git a/mmdeploy/apis/test.py b/mmdeploy/apis/test.py index 8350900d8..a94d9981c 100644 --- a/mmdeploy/apis/test.py +++ b/mmdeploy/apis/test.py @@ -10,13 +10,13 @@ from mmdeploy.apis.utils import assert_module_exist def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]): + assert_module_exist(codebase) # load model_cfg if necessary if isinstance(model_cfg, str): model_cfg = mmcv.Config.fromfile(model_cfg) if codebase == 'mmcls': from mmcls.datasets import (build_dataloader, build_dataset) - assert_module_exist(codebase) # build dataset and dataloader dataset = build_dataset(model_cfg.data.test) data_loader = build_dataloader( @@ -27,7 +27,6 @@ def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]): round_up=False) elif codebase == 'mmdet': - assert_module_exist(codebase) from mmdet.datasets import (build_dataloader, build_dataset, replace_ImageToTensor) # in case the test dataset is concatenated @@ -58,7 +57,17 @@ def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]): workers_per_gpu=model_cfg.data.workers_per_gpu, dist=False, shuffle=False) - + elif codebase == 'mmseg': + from mmseg.datasets import build_dataset, build_dataloader + model_cfg.data.test.test_mode = True + dataset = build_dataset(model_cfg.data.test) + samples_per_gpu = 1 + data_loader = build_dataloader( + dataset, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=model_cfg.data.workers_per_gpu, + dist=False, + shuffle=False) else: raise NotImplementedError(f'Unknown codebase type: {codebase}') @@ -71,16 +80,18 @@ def single_gpu_test(codebase: str, show: bool = False, out_dir: Any = None, show_score_thr: float = 0.3): + assert_module_exist(codebase) + if codebase == 'mmcls': - assert_module_exist(codebase) from mmcls.apis import single_gpu_test outputs = single_gpu_test(model, data_loader, show, out_dir) elif codebase == 'mmdet': - assert_module_exist(codebase) from mmdet.apis import single_gpu_test outputs = single_gpu_test(model, data_loader, show, out_dir, show_score_thr) - + elif codebase == 'mmseg': + from mmseg.apis import single_gpu_test + outputs = single_gpu_test(model, data_loader, show, out_dir) else: raise NotImplementedError(f'Unknown codebase type: {codebase}') return outputs @@ -138,5 +149,15 @@ def post_process_outputs(outputs, eval_kwargs.update(dict(metric=metrics, **kwargs)) print(dataset.evaluate(outputs, **eval_kwargs)) + elif codebase == 'mmseg': + if out: + print(f'\nwriting results to {out}') + mmcv.dump(outputs, out) + kwargs = {} if metric_options is None else metric_options + if format_only: + dataset.format_results(outputs, **kwargs) + if metrics: + dataset.evaluate(outputs, metrics, **kwargs) + else: raise NotImplementedError(f'Unknown codebase type: {codebase}') diff --git a/mmdeploy/apis/utils.py b/mmdeploy/apis/utils.py index 4ede72259..9fd4ac5cb 100644 --- a/mmdeploy/apis/utils.py +++ b/mmdeploy/apis/utils.py @@ -31,49 +31,50 @@ def init_model(codebase: str, model_checkpoint: Optional[str] = None, device: str = 'cuda:0', cfg_options: Optional[Dict] = None): - # mmcls + assert_module_exist(codebase) if codebase == 'mmcls': - assert_module_exist(codebase) from mmcls.apis import init_model model = init_model(model_cfg, model_checkpoint, device, cfg_options) elif codebase == 'mmdet': - assert_module_exist(codebase) from mmdet.apis import init_detector model = init_detector(model_cfg, model_checkpoint, device, cfg_options) elif codebase == 'mmseg': - assert_module_exist(codebase) from mmseg.apis import init_segmentor + from mmdeploy.mmseg.export import convert_syncbatchnorm model = init_segmentor(model_cfg, model_checkpoint, device) + model = convert_syncbatchnorm(model) else: raise NotImplementedError(f'Unknown codebase type: {codebase}') - return model + return model.eval() def create_input(codebase: str, model_cfg: Union[str, mmcv.Config], imgs: Any, device: str = 'cuda:0'): + assert_module_exist(codebase) if isinstance(model_cfg, str): model_cfg = mmcv.Config.fromfile(model_cfg) elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)): raise TypeError('config must be a filename or Config object, ' f'but got {type(model_cfg)}') - cfg = model_cfg.copy() if codebase == 'mmcls': - assert_module_exist(codebase) from mmdeploy.mmcls.export import create_input return create_input(cfg, imgs, device) elif codebase == 'mmdet': - assert_module_exist(codebase) from mmdeploy.mmdet.export import create_input return create_input(cfg, imgs, device) + elif codebase == 'mmseg': + from mmdeploy.mmseg.export import create_input + return create_input(cfg, imgs, device) + else: raise NotImplementedError(f'Unknown codebase type: {codebase}') @@ -94,8 +95,8 @@ def init_backend_model(model_files: Sequence[str], backend: str, class_names: Sequence[str], device_id: int = 0): + assert_module_exist(codebase) if codebase == 'mmcls': - assert_module_exist(codebase) if backend == 'onnxruntime': from mmdeploy.mmcls.export import ONNXRuntimeClassifier backend_model = ONNXRuntimeClassifier( @@ -120,7 +121,6 @@ def init_backend_model(model_files: Sequence[str], return backend_model elif codebase == 'mmdet': - assert_module_exist(codebase) if backend == 'onnxruntime': from mmdeploy.mmdet.export import ONNXRuntimeDetector backend_model = ONNXRuntimeDetector( @@ -137,61 +137,55 @@ def init_backend_model(model_files: Sequence[str], raise NotImplementedError(f'Unsupported backend type: {backend}') return backend_model + elif codebase == 'mmseg': + if backend == 'onnxruntime': + from mmdeploy.mmseg.export import ONNXRuntimeSegmentor + backend_model = ONNXRuntimeSegmentor( + model_files[0], class_names=class_names, device_id=device_id) + elif backend == 'tensorrt': + from mmdeploy.mmseg.export import TensorRTSegmentor + backend_model = TensorRTSegmentor( + model_files[0], class_names=class_names, device_id=device_id) + else: + raise NotImplementedError(f'Unsupported backend type: {backend}') + return backend_model else: raise NotImplementedError(f'Unknown codebase type: {codebase}') def get_classes_from_config(codebase: str, model_cfg: Union[str, mmcv.Config]): + assert_module_exist(codebase) + model_cfg_str = model_cfg + if isinstance(model_cfg, str): + model_cfg = mmcv.Config.fromfile(model_cfg) + elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)): + raise TypeError('config must be a filename or Config object, ' + f'but got {type(model_cfg)}') + if codebase == 'mmcls': - assert_module_exist(codebase) - if isinstance(model_cfg, str): - model_cfg = mmcv.Config.fromfile(model_cfg) - elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)): - raise TypeError('config must be a filename or Config object, ' - f'but got {type(model_cfg)}') - from mmcls.datasets import DATASETS - module_dict = DATASETS.module_dict - data_cfg = model_cfg.data - - if 'train' in data_cfg: - module = module_dict[data_cfg.train.type] - elif 'val' in data_cfg: - module = module_dict[data_cfg.val.type] - elif 'test' in data_cfg: - module = module_dict[data_cfg.test.type] - else: - raise RuntimeError(f'No dataset config found in: {model_cfg_str}') - - return module.CLASSES - - if codebase == 'mmdet': - assert_module_exist(codebase) - if isinstance(model_cfg, str): - model_cfg = mmcv.Config.fromfile(model_cfg) - elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)): - raise TypeError('config must be a filename or Config object, ' - f'but got {type(model_cfg)}') - + elif codebase == 'mmdet': from mmdet.datasets import DATASETS - module_dict = DATASETS.module_dict - data_cfg = model_cfg.data - - if 'train' in data_cfg: - module = module_dict[data_cfg.train.type] - elif 'val' in data_cfg: - module = module_dict[data_cfg.val.type] - elif 'test' in data_cfg: - module = module_dict[data_cfg.test.type] - else: - raise RuntimeError(f'No dataset config found in: {model_cfg_str}') - - return module.CLASSES - + elif codebase == 'mmseg': + from mmseg.datasets import DATASETS else: raise NotImplementedError(f'Unknown codebase type: {codebase}') + module_dict = DATASETS.module_dict + data_cfg = model_cfg.data + + if 'train' in data_cfg: + module = module_dict[data_cfg.train.type] + elif 'val' in data_cfg: + module = module_dict[data_cfg.val.type] + elif 'test' in data_cfg: + module = module_dict[data_cfg.test.type] + else: + raise RuntimeError(f'No dataset config found in: {model_cfg_str}') + + return module.CLASSES + def check_model_outputs(codebase: str, image: Union[str, np.ndarray], @@ -199,10 +193,12 @@ def check_model_outputs(codebase: str, model, output_file: str, backend: str, + dataset: str = None, show_result=False): + assert_module_exist(codebase) show_img = mmcv.imread(image) if isinstance(image, str) else image + if codebase == 'mmcls': - assert_module_exist(codebase) output_file = None if show_result else output_file with torch.no_grad(): scores = model(**model_inputs, return_loss=False)[0] @@ -221,7 +217,6 @@ def check_model_outputs(codebase: str, out_file=output_file) elif codebase == 'mmdet': - assert_module_exist(codebase) output_file = None if show_result else output_file score_thr = 0.3 with torch.no_grad(): @@ -234,5 +229,20 @@ def check_model_outputs(codebase: str, win_name=backend, out_file=output_file) + elif codebase == 'mmseg': + output_file = None if show_result else output_file + from mmseg.core.evaluation import get_palette + dataset = 'cityscapes' if dataset is None else dataset + palette = get_palette(dataset) + with torch.no_grad(): + results = model(**model_inputs, return_loss=False, rescale=True) + model.show_result( + show_img, + results, + palette=palette, + show=True, + win_name=backend, + out_file=output_file, + opacity=0.5) else: raise NotImplementedError(f'Unknown codebase type: {codebase}') diff --git a/mmdeploy/mmdet/export/model_wrappers.py b/mmdeploy/mmdet/export/model_wrappers.py index 1fe8c1c39..44ebca773 100644 --- a/mmdeploy/mmdet/export/model_wrappers.py +++ b/mmdeploy/mmdet/export/model_wrappers.py @@ -74,7 +74,8 @@ class DeployBaseDetector(BaseDetector): img_h, img_w = img_metas[i]['img_shape'][:2] ori_h, ori_w = img_metas[i]['ori_shape'][:2] masks = masks[:, :img_h, :img_w] - if rescale and batch_masks.shape[1] > 0: + # avoid to resize masks with zero dim + if rescale and masks.shape[0] != 0: masks = masks.astype(np.float32) masks = torch.from_numpy(masks) masks = torch.nn.functional.interpolate( diff --git a/mmdeploy/mmseg/__init__.py b/mmdeploy/mmseg/__init__.py index e69de29bb..d2b62e1cb 100644 --- a/mmdeploy/mmseg/__init__.py +++ b/mmdeploy/mmseg/__init__.py @@ -0,0 +1,2 @@ +from .export import * # noqa: F401,F403 +from .models import * # noqa: F401,F403 diff --git a/mmdeploy/mmseg/export/__init__.py b/mmdeploy/mmseg/export/__init__.py index e69de29bb..93fabdac7 100644 --- a/mmdeploy/mmseg/export/__init__.py +++ b/mmdeploy/mmseg/export/__init__.py @@ -0,0 +1,8 @@ +from .model_wrappers import ONNXRuntimeSegmentor, TensorRTSegmentor +from .onnx_helper import convert_syncbatchnorm +from .prepare_input import create_input + +__all__ = [ + 'create_input', 'ONNXRuntimeSegmentor', 'TensorRTSegmentor', + 'convert_syncbatchnorm' +] diff --git a/mmdeploy/mmseg/export/model_wrappers.py b/mmdeploy/mmseg/export/model_wrappers.py new file mode 100644 index 000000000..c9d06256d --- /dev/null +++ b/mmdeploy/mmseg/export/model_wrappers.py @@ -0,0 +1,117 @@ +import os.path as osp +import warnings +from typing import Sequence + +import numpy as np +import torch +from mmseg.models.segmentors.base import BaseSegmentor +from mmseg.ops import resize + + +class DeployBaseSegmentor(BaseSegmentor): + + def __init__(self, class_names: Sequence[str], device_id: int): + super(DeployBaseSegmentor, self).__init__(init_cfg=None) + self.CLASSES = class_names + self.device_id = device_id + self.PALETTE = None + + def extract_feat(self, imgs): + raise NotImplementedError('This method is not implemented.') + + def encode_decode(self, img, img_metas): + raise NotImplementedError('This method is not implemented.') + + def forward_train(self, imgs, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def simple_test(self, img, img_meta, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def aug_test(self, imgs, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def forward(self, img, img_metas, **kwargs): + seg_pred = self.forward_test(img, img_metas, **kwargs) + # whole mode supports dynamic shape + ori_shape = img_metas[0][0]['ori_shape'] + if not (ori_shape[0] == seg_pred.shape[-2] + and ori_shape[1] == seg_pred.shape[-1]): + seg_pred = torch.from_numpy(seg_pred).float() + seg_pred = resize( + seg_pred, size=tuple(ori_shape[:2]), mode='nearest') + seg_pred = seg_pred.long().detach().cpu().numpy() + # remove unnecessary dim + seg_pred = seg_pred.squeeze(1) + seg_pred = list(seg_pred) + return seg_pred + + +class ONNXRuntimeSegmentor(DeployBaseSegmentor): + + def __init__(self, onnx_file: str, class_names: Sequence[str], + device_id: int): + super(ONNXRuntimeSegmentor, self).__init__(class_names, device_id) + + import onnxruntime as ort + from mmdeploy.apis.onnxruntime import get_ops_path + + # get the custom op path + ort_custom_op_path = get_ops_path() + session_options = ort.SessionOptions() + # register custom op for onnxruntime + if osp.exists(ort_custom_op_path): + session_options.register_custom_ops_library(ort_custom_op_path) + sess = ort.InferenceSession(onnx_file, session_options) + providers = ['CPUExecutionProvider'] + options = [{}] + is_cuda_available = ort.get_device() == 'GPU' + if is_cuda_available: + providers.insert(0, 'CUDAExecutionProvider') + options.insert(0, {'device_id': device_id}) + + sess.set_providers(providers, options) + + self.sess = sess + self.io_binding = sess.io_binding() + self.output_names = [_.name for _ in sess.get_outputs()] + for name in self.output_names: + self.io_binding.bind_output(name) + + def forward_test(self, imgs, img_metas, **kwargs): + input_data = imgs[0] + device_type = input_data.device.type + self.io_binding.bind_input( + name='input', + device_type=device_type, + device_id=self.device_id, + element_type=np.float32, + shape=input_data.shape, + buffer_ptr=input_data.data_ptr()) + self.sess.run_with_iobinding(self.io_binding) + seg_pred = self.io_binding.copy_outputs_to_cpu()[0] + return seg_pred + + +class TensorRTSegmentor(DeployBaseSegmentor): + + def __init__(self, trt_file: str, class_names: Sequence[str], + device_id: int): + super(TensorRTSegmentor, self).__init__(class_names, device_id) + + from mmdeploy.apis.tensorrt import TRTWrapper, load_tensorrt_plugin + try: + load_tensorrt_plugin() + except (ImportError, ModuleNotFoundError): + warnings.warn('If input model has custom plugins, \ + you may have to build backend ops with TensorRT') + model = TRTWrapper(trt_file) + self.model = model + self.output_name = self.model.output_names[0] + + def forward_test(self, imgs, img_metas, **kwargs): + input_data = imgs[0].contiguous() + with torch.cuda.device(self.device_id), torch.no_grad(): + seg_pred = self.model({'input': input_data})[self.output_name] + seg_pred = seg_pred.detach().cpu().numpy() + return seg_pred diff --git a/mmdeploy/mmseg/export/onnx_helper.py b/mmdeploy/mmseg/export/onnx_helper.py new file mode 100644 index 000000000..fd39f0b52 --- /dev/null +++ b/mmdeploy/mmseg/export/onnx_helper.py @@ -0,0 +1,22 @@ +import torch + + +def convert_syncbatchnorm(module): + module_output = module + if isinstance(module, torch.nn.SyncBatchNorm): + module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, + module.momentum, module.affine, + module.track_running_stats) + if module.affine: + module_output.weight.data = module.weight.data.clone().detach() + module_output.bias.data = module.bias.data.clone().detach() + # keep requires_grad unchanged + module_output.weight.requires_grad = module.weight.requires_grad + module_output.bias.requires_grad = module.bias.requires_grad + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + for name, child in module.named_children(): + module_output.add_module(name, convert_syncbatchnorm(child)) + del module + return module_output diff --git a/mmdeploy/mmseg/export/prepare_input.py b/mmdeploy/mmseg/export/prepare_input.py new file mode 100644 index 000000000..e4b938bae --- /dev/null +++ b/mmdeploy/mmseg/export/prepare_input.py @@ -0,0 +1,47 @@ +from typing import Any, Union + +import mmcv +import numpy as np +from mmcv.parallel import collate, scatter +from mmseg.apis.inference import LoadImage +from mmseg.datasets.pipelines import Compose + + +def create_input(model_cfg: Union[str, mmcv.Config], + imgs: Any, + device: str = 'cuda:0'): + if isinstance(model_cfg, str): + model_cfg = mmcv.Config.fromfile(model_cfg) + elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)): + raise TypeError('config must be a filename or Config object, ' + f'but got {type(model_cfg)}') + cfg = model_cfg.copy() + + if not isinstance(imgs, (list, tuple)): + imgs = [imgs] + + if isinstance(imgs[0], np.ndarray): + cfg = cfg.copy() + # set loading pipeline type + cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam' + + cfg.data.test.pipeline[1]['transforms'][0]['keep_ratio'] = False + cfg.data.test.pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] + + test_pipeline = Compose(cfg.data.test.pipeline) + datas = [] + for img in imgs: + # prepare data + data = dict(img=img) + # build the data pipeline + data = test_pipeline(data) + datas.append(data) + + data = collate(datas, samples_per_gpu=len(imgs)) + + data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']] + data['img'] = [img.data[0][None, :] for img in data['img']] + if device != 'cpu': + data = scatter(data, [device])[0] + + return data, data['img'] diff --git a/mmdeploy/mmseg/models/__init__.py b/mmdeploy/mmseg/models/__init__.py new file mode 100644 index 000000000..4d89a2ae7 --- /dev/null +++ b/mmdeploy/mmseg/models/__init__.py @@ -0,0 +1,2 @@ +from .decode_heads import * # noqa: F401,F403 +from .segmentors import * # noqa: F401,F403 diff --git a/mmdeploy/mmseg/models/decode_heads/__init__.py b/mmdeploy/mmseg/models/decode_heads/__init__.py new file mode 100644 index 000000000..4994e6395 --- /dev/null +++ b/mmdeploy/mmseg/models/decode_heads/__init__.py @@ -0,0 +1,4 @@ +from .aspp_head import forward_of_aspp_head +from .psp_head import forward_of_ppm + +__all__ = ['forward_of_aspp_head', 'forward_of_ppm'] diff --git a/mmdeploy/mmseg/models/decode_heads/aspp_head.py b/mmdeploy/mmseg/models/decode_heads/aspp_head.py new file mode 100644 index 000000000..535695f69 --- /dev/null +++ b/mmdeploy/mmseg/models/decode_heads/aspp_head.py @@ -0,0 +1,30 @@ +import torch +from mmseg.ops import resize + +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import is_dynamic_shape + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmseg.models.decode_heads.ASPPHead.forward') +def forward_of_aspp_head(ctx, self, inputs): + x = self._transform_inputs(inputs) + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + # get origin input shape as tensor to support onnx dynamic shape + size = x.shape[2:] + if not is_dynamic_flag: + size = [int(val) for val in size] + + aspp_outs = [ + resize( + self.image_pool(x), + size=size, + mode='bilinear', + align_corners=self.align_corners) + ] + aspp_outs.extend(self.aspp_modules(x)) + aspp_outs = torch.cat(aspp_outs, dim=1) + output = self.bottleneck(aspp_outs) + output = self.cls_seg(output) + return output diff --git a/mmdeploy/mmseg/models/decode_heads/psp_head.py b/mmdeploy/mmseg/models/decode_heads/psp_head.py new file mode 100644 index 000000000..1b6cbe628 --- /dev/null +++ b/mmdeploy/mmseg/models/decode_heads/psp_head.py @@ -0,0 +1,26 @@ +from mmseg.ops import resize + +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import is_dynamic_shape + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmseg.models.decode_heads.psp_head.PPM.forward') +def forward_of_ppm(ctx, self, x): + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + # get origin input shape as tensor to support onnx dynamic shape + size = x.shape[2:] + if not is_dynamic_flag: + size = [int(val) for val in size] + + ppm_outs = [] + for ppm in self: + ppm_out = ppm(x) + upsampled_ppm_out = resize( + ppm_out, + size=size, + mode='bilinear', + align_corners=self.align_corners) + ppm_outs.append(upsampled_ppm_out) + return ppm_outs diff --git a/mmdeploy/mmseg/models/segmentors/__init__.py b/mmdeploy/mmseg/models/segmentors/__init__.py new file mode 100644 index 000000000..343264f29 --- /dev/null +++ b/mmdeploy/mmseg/models/segmentors/__init__.py @@ -0,0 +1,4 @@ +from .base import forward_of_base_segmentor +from .encoder_decoder import simple_test_of_encoder_decoder + +__all__ = ['forward_of_base_segmentor', 'simple_test_of_encoder_decoder'] diff --git a/mmdeploy/mmseg/models/segmentors/base.py b/mmdeploy/mmseg/models/segmentors/base.py new file mode 100644 index 000000000..f31f441da --- /dev/null +++ b/mmdeploy/mmseg/models/segmentors/base.py @@ -0,0 +1,22 @@ +import torch + +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import is_dynamic_shape + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmseg.models.segmentors.BaseSegmentor.forward') +def forward_of_base_segmentor(ctx, self, img, img_metas=None, **kwargs): + if img_metas is None: + img_metas = {} + assert isinstance(img_metas, dict) + assert isinstance(img, torch.Tensor) + + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + # get origin input shape as tensor to support onnx dynamic shape + img_shape = img.shape[2:] + if not is_dynamic_flag: + img_shape = [int(val) for val in img_shape] + img_metas['img_shape'] = img_shape + return self.simple_test(img, img_metas, **kwargs) diff --git a/mmdeploy/mmseg/models/segmentors/encoder_decoder.py b/mmdeploy/mmseg/models/segmentors/encoder_decoder.py new file mode 100644 index 000000000..e5f77b0e3 --- /dev/null +++ b/mmdeploy/mmseg/models/segmentors/encoder_decoder.py @@ -0,0 +1,21 @@ +import torch.nn.functional as F +from mmseg.ops import resize + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmseg.models.segmentors.EncoderDecoder.simple_test') +def simple_test_of_encoder_decoder(ctx, self, img, img_meta, **kwargs): + x = self.extract_feat(img) + seg_logit = self._decode_head_forward_test(x, img_meta) + seg_logit = resize( + input=seg_logit, + size=img_meta['img_shape'], + mode='bilinear', + align_corners=self.align_corners) + seg_logit = F.softmax(seg_logit, dim=1) + seg_pred = seg_logit.argmax(dim=1) + # our inference backend only support 4D output + seg_pred = seg_pred.unsqueeze(1) + return seg_pred