From 5ebd10b029cd2efbf3d446b4ea03afa0eda5011d Mon Sep 17 00:00:00 2001
From: AllentDan <41138331+AllentDan@users.noreply.github.com>
Date: Thu, 27 Apr 2023 16:01:45 +0800
Subject: [PATCH] Add with argmax in config for mmseg (#2038)

* add with_argmax for model conversion in mmseg

* resolve lint
---
 .../segmentation_rknn-fp16_static-320x320.py |  2 +-
 .../segmentation_rknn-int8_static-320x320.py |  2 +-
 configs/mmseg/segmentation_static.py         |  2 +-
 docs/en/04-supported-codebases/mmseg.md      |  2 ++
 docs/zh_cn/04-supported-codebases/mmseg.md   |  2 ++
 .../codebase/mmseg/deploy/segmentation.py    |  6 +++-
 .../mmseg/deploy/segmentation_model.py       | 36 ++-----------------
 .../models/segmentors/encoder_decoder.py     | 31 +++-------------
 8 files changed, 20 insertions(+), 63 deletions(-)

diff --git a/configs/mmseg/segmentation_rknn-fp16_static-320x320.py b/configs/mmseg/segmentation_rknn-fp16_static-320x320.py
index 31a8edf71..3ab538fc5 100644
--- a/configs/mmseg/segmentation_rknn-fp16_static-320x320.py
+++ b/configs/mmseg/segmentation_rknn-fp16_static-320x320.py
@@ -2,7 +2,7 @@ _base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
 
 onnx_config = dict(input_shape=[320, 320])
 
-codebase_config = dict(model_type='rknn')
+codebase_config = dict(with_argmax=False)
 
 backend_config = dict(
     input_size_list=[[3, 320, 320]],
diff --git a/configs/mmseg/segmentation_rknn-int8_static-320x320.py b/configs/mmseg/segmentation_rknn-int8_static-320x320.py
index 2bb908234..ccf20f728 100644
--- a/configs/mmseg/segmentation_rknn-int8_static-320x320.py
+++ b/configs/mmseg/segmentation_rknn-int8_static-320x320.py
@@ -2,6 +2,6 @@ _base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
 
 onnx_config = dict(input_shape=[320, 320])
 
-codebase_config = dict(model_type='rknn')
+codebase_config = dict(with_argmax=False)
 
 backend_config = dict(input_size_list=[[3, 320, 320]])
diff --git a/configs/mmseg/segmentation_static.py b/configs/mmseg/segmentation_static.py
index 434a8fae9..416b781ae 100644
--- a/configs/mmseg/segmentation_static.py
+++ b/configs/mmseg/segmentation_static.py
@@ -1,2 +1,2 @@
 _base_ = ['../_base_/onnx_config.py']
-codebase_config = dict(type='mmseg', task='Segmentation')
+codebase_config = dict(type='mmseg', task='Segmentation', with_argmax=True)
diff --git a/docs/en/04-supported-codebases/mmseg.md b/docs/en/04-supported-codebases/mmseg.md
index de7469556..27af0bd81 100644
--- a/docs/en/04-supported-codebases/mmseg.md
+++ b/docs/en/04-supported-codebases/mmseg.md
@@ -231,3 +231,5 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
 - <i id="static_shape">PSPNet, Fast-SCNN</i> only support static shape, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) is not supported by most inference backends.
 
 - For models that only supports static shape, you should use the deployment config file of static shape such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.
+
+- For users who prefer the deployed model to output a probability feature map instead of a label map, set `codebase_config = dict(with_argmax=False)` in the deploy config.
diff --git a/docs/zh_cn/04-supported-codebases/mmseg.md b/docs/zh_cn/04-supported-codebases/mmseg.md
index 3e1a03a36..c5b0293df 100644
--- a/docs/zh_cn/04-supported-codebases/mmseg.md
+++ b/docs/zh_cn/04-supported-codebases/mmseg.md
@@ -235,3 +235,5 @@ cv2.imwrite('output_segmentation.png', img)
 - <i id=“static_shape”>PSPNet,Fast-SCNN</i> 仅支持静态输入,因为多数推理框架的 [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) 不支持动态输入。
 
 - 对于仅支持静态形状的模型,应使用静态形状的部署配置文件,例如 `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`
+
+- 若希望部署后的模型输出概率特征图而非标签图,在部署配置中设置 `codebase_config = dict(with_argmax=False)` 即可。
diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation.py b/mmdeploy/codebase/mmseg/deploy/segmentation.py
index 2dbef4d49..3b8adec1e 100644
--- a/mmdeploy/codebase/mmseg/deploy/segmentation.py
+++ b/mmdeploy/codebase/mmseg/deploy/segmentation.py
@@ -13,7 +13,8 @@ from mmengine.model import BaseDataPreprocessor
 from mmengine.registry import Registry
 
 from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
-from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger
+from mmdeploy.utils import (Codebase, Task, get_codebase_config,
+                            get_input_shape, get_root_logger)
 
 
 def process_model_config(model_cfg: mmengine.Config,
@@ -303,6 +304,9 @@ class Segmentation(BaseTask):
         if isinstance(params, list):
             params = params[-1]
         postprocess = dict(params=params, type='ResizeMask')
+        with_argmax = get_codebase_config(self.deploy_cfg).get(
+            'with_argmax', True)
+        postprocess['with_argmax'] = with_argmax
         return postprocess
 
     def get_model_name(self, *args, **kwargs) -> str:
diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation_model.py b/mmdeploy/codebase/mmseg/deploy/segmentation_model.py
index a8cc02290..cc87aa721 100644
--- a/mmdeploy/codebase/mmseg/deploy/segmentation_model.py
+++ b/mmdeploy/codebase/mmseg/deploy/segmentation_model.py
@@ -105,6 +105,9 @@ class End2EndModel(BaseBackendModel):
         for seg_pred, data_sample in zip(batch_outputs, data_samples):
             # resize seg_pred to original image shape
             metainfo = data_sample.metainfo
+            if get_codebase_config(self.deploy_cfg).get('with_argmax',
+                                                        True) is False:
+                seg_pred = seg_pred.argmax(dim=0, keepdim=True)
             if metainfo['ori_shape'] != metainfo['img_shape']:
                 from mmseg.models.utils import resize
                 ori_type = seg_pred.dtype
@@ -119,39 +122,6 @@ class End2EndModel(BaseBackendModel):
         return predictions
 
 
-@__BACKEND_MODEL.register_module('rknn')
-class RKNNModel(End2EndModel):
-    """SDK inference class, converts RKNN output to mmseg format."""
-
-    def forward(self,
-                inputs: torch.Tensor,
-                data_samples: Optional[List[BaseDataElement]] = None,
-                mode: str = 'predict'):
-        """Run forward inference.
-
-        Args:
-            inputs (Tensor): Inputs with shape (N, C, H, W).
-            data_samples (list[:obj:`SegDataSample`]): The seg data
-                samples. It usually includes information such as
-                `metainfo` and `gt_sem_seg`. Default to None.
-
-        Returns:
-            list: A list contains predictions.
- """ - assert mode == 'predict', \ - 'Backend model only support mode==predict,' f' but get {mode}' - if inputs.device != torch.device(self.device): - get_root_logger().warning(f'expect input device {self.device}' - f' but get {inputs.device}.') - inputs = inputs.to(self.device) - batch_outputs = self.wrapper({self.input_name: inputs}) - batch_outputs = [ - output.argmax(dim=1, keepdim=True) - for output in batch_outputs.values() - ] - return self.pack_result(batch_outputs[0], data_samples) - - @__BACKEND_MODEL.register_module('vacc_seg') class VACCModel(End2EndModel): """SDK inference class, converts VACC output to mmseg format.""" diff --git a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py index f83fb9d0e..312c0fbfe 100644 --- a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py +++ b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmdeploy.core import FUNCTION_REWRITER, mark -from mmdeploy.utils.constants import Backend +from mmdeploy.utils import get_codebase_config @FUNCTION_REWRITER.register_rewriter( @@ -26,6 +26,10 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs): x = self.extract_feat(inputs) seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg) + ctx = FUNCTION_REWRITER.get_context() + if get_codebase_config(ctx.cfg).get('with_argmax', True) is False: + return seg_logit + # mark seg_head @mark('decode_head', outputs=['output']) def __mark_seg_logit(seg_logit): @@ -35,28 +39,3 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs): seg_pred = seg_logit.argmax(dim=1, keepdim=True) return seg_pred - - -@FUNCTION_REWRITER.register_rewriter( - func_name='mmseg.models.segmentors.EncoderDecoder.predict', - backend=Backend.RKNN.value) -def encoder_decoder__predict__rknn(self, inputs, data_samples, **kwargs): - """Rewrite `predict` for RKNN backend. - - Early return to avoid argmax operator. - - Args: - ctx (ContextCaller): The context with additional information. - self: The instance of the original class. - inputs (Tensor): Inputs with shape (N, C, H, W). - data_samples (SampleList): The seg data samples. - - Returns: - torch.Tensor: Output segmentation map pf shape [N, 1, H, W]. - """ - batch_img_metas = [] - for data_sample in data_samples: - batch_img_metas.append(data_sample.metainfo) - x = self.extract_feat(inputs) - seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg) - return seg_logit