From 5ebd10b029cd2efbf3d446b4ea03afa0eda5011d Mon Sep 17 00:00:00 2001
From: AllentDan <41138331+AllentDan@users.noreply.github.com>
Date: Thu, 27 Apr 2023 16:01:45 +0800
Subject: [PATCH] Add with argmax in config for mmseg (#2038)

* add with_argmax for model conversion in mmseg

* resolve lint
---
 .../segmentation_rknn-fp16_static-320x320.py  |  2 +-
 .../segmentation_rknn-int8_static-320x320.py  |  2 +-
 configs/mmseg/segmentation_static.py          |  2 +-
 docs/en/04-supported-codebases/mmseg.md       |  2 ++
 docs/zh_cn/04-supported-codebases/mmseg.md    |  2 ++
 .../codebase/mmseg/deploy/segmentation.py     |  6 +++-
 .../mmseg/deploy/segmentation_model.py        | 36 ++-----------------
 .../models/segmentors/encoder_decoder.py      | 31 +++-------------
 8 files changed, 20 insertions(+), 63 deletions(-)

diff --git a/configs/mmseg/segmentation_rknn-fp16_static-320x320.py b/configs/mmseg/segmentation_rknn-fp16_static-320x320.py
index 31a8edf71..3ab538fc5 100644
--- a/configs/mmseg/segmentation_rknn-fp16_static-320x320.py
+++ b/configs/mmseg/segmentation_rknn-fp16_static-320x320.py
@@ -2,7 +2,7 @@ _base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
 
 onnx_config = dict(input_shape=[320, 320])
 
-codebase_config = dict(model_type='rknn')
+codebase_config = dict(with_argmax=False)
 
 backend_config = dict(
     input_size_list=[[3, 320, 320]],
diff --git a/configs/mmseg/segmentation_rknn-int8_static-320x320.py b/configs/mmseg/segmentation_rknn-int8_static-320x320.py
index 2bb908234..ccf20f728 100644
--- a/configs/mmseg/segmentation_rknn-int8_static-320x320.py
+++ b/configs/mmseg/segmentation_rknn-int8_static-320x320.py
@@ -2,6 +2,6 @@ _base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
 
 onnx_config = dict(input_shape=[320, 320])
 
-codebase_config = dict(model_type='rknn')
+codebase_config = dict(with_argmax=False)
 
 backend_config = dict(input_size_list=[[3, 320, 320]])
diff --git a/configs/mmseg/segmentation_static.py b/configs/mmseg/segmentation_static.py
index 434a8fae9..416b781ae 100644
--- a/configs/mmseg/segmentation_static.py
+++ b/configs/mmseg/segmentation_static.py
@@ -1,2 +1,2 @@
 _base_ = ['../_base_/onnx_config.py']
-codebase_config = dict(type='mmseg', task='Segmentation')
+codebase_config = dict(type='mmseg', task='Segmentation', with_argmax=True)
diff --git a/docs/en/04-supported-codebases/mmseg.md b/docs/en/04-supported-codebases/mmseg.md
index de7469556..27af0bd81 100644
--- a/docs/en/04-supported-codebases/mmseg.md
+++ b/docs/en/04-supported-codebases/mmseg.md
@@ -231,3 +231,5 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
 - <i id="static_shape">PSPNet, Fast-SCNN</i> only support static shape, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) is not supported by most inference backends.
 
 - For models that only supports static shape, you should use the deployment config file of static shape such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.
+
+- For users prefer deployed models generate probability feature map, put `codebase_config = dict(with_argmax=False)` in deploy configs.
diff --git a/docs/zh_cn/04-supported-codebases/mmseg.md b/docs/zh_cn/04-supported-codebases/mmseg.md
index 3e1a03a36..c5b0293df 100644
--- a/docs/zh_cn/04-supported-codebases/mmseg.md
+++ b/docs/zh_cn/04-supported-codebases/mmseg.md
@@ -235,3 +235,5 @@ cv2.imwrite('output_segmentation.png', img)
 - <i id=“static_shape”>PSPNet,Fast-SCNN</i> 仅支持静态输入,因为多数推理框架的 [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) 不支持动态输入。
 
 - 对于仅支持静态形状的模型,应使用静态形状的部署配置文件,例如 `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`
+
+- 对于喜欢部署模型生成概率特征图的用户,将 `codebase_config = dict(with_argmax=False)` 放在部署配置中就足够了。
diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation.py b/mmdeploy/codebase/mmseg/deploy/segmentation.py
index 2dbef4d49..3b8adec1e 100644
--- a/mmdeploy/codebase/mmseg/deploy/segmentation.py
+++ b/mmdeploy/codebase/mmseg/deploy/segmentation.py
@@ -13,7 +13,8 @@ from mmengine.model import BaseDataPreprocessor
 from mmengine.registry import Registry
 
 from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
-from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger
+from mmdeploy.utils import (Codebase, Task, get_codebase_config,
+                            get_input_shape, get_root_logger)
 
 
 def process_model_config(model_cfg: mmengine.Config,
@@ -303,6 +304,9 @@ class Segmentation(BaseTask):
         if isinstance(params, list):
             params = params[-1]
         postprocess = dict(params=params, type='ResizeMask')
+        with_argmax = get_codebase_config(self.deploy_cfg).get(
+            'with_argmax', True)
+        postprocess['with_argmax'] = with_argmax
         return postprocess
 
     def get_model_name(self, *args, **kwargs) -> str:
diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation_model.py b/mmdeploy/codebase/mmseg/deploy/segmentation_model.py
index a8cc02290..cc87aa721 100644
--- a/mmdeploy/codebase/mmseg/deploy/segmentation_model.py
+++ b/mmdeploy/codebase/mmseg/deploy/segmentation_model.py
@@ -105,6 +105,9 @@ class End2EndModel(BaseBackendModel):
         for seg_pred, data_sample in zip(batch_outputs, data_samples):
             # resize seg_pred to original image shape
             metainfo = data_sample.metainfo
+            if get_codebase_config(self.deploy_cfg).get('with_argmax',
+                                                        True) is False:
+                seg_pred = seg_pred.argmax(dim=0, keepdim=True)
             if metainfo['ori_shape'] != metainfo['img_shape']:
                 from mmseg.models.utils import resize
                 ori_type = seg_pred.dtype
@@ -119,39 +122,6 @@ class End2EndModel(BaseBackendModel):
         return predictions
 
 
-@__BACKEND_MODEL.register_module('rknn')
-class RKNNModel(End2EndModel):
-    """SDK inference class, converts RKNN output to mmseg format."""
-
-    def forward(self,
-                inputs: torch.Tensor,
-                data_samples: Optional[List[BaseDataElement]] = None,
-                mode: str = 'predict'):
-        """Run forward inference.
-
-        Args:
-            inputs (Tensor): Inputs with shape (N, C, H, W).
-            data_samples (list[:obj:`SegDataSample`]): The seg data
-                samples. It usually includes information such as
-                `metainfo` and `gt_sem_seg`. Default to None.
-
-        Returns:
-            list: A list contains predictions.
-        """
-        assert mode == 'predict', \
-            'Backend model only support mode==predict,' f' but get {mode}'
-        if inputs.device != torch.device(self.device):
-            get_root_logger().warning(f'expect input device {self.device}'
-                                      f' but get {inputs.device}.')
-        inputs = inputs.to(self.device)
-        batch_outputs = self.wrapper({self.input_name: inputs})
-        batch_outputs = [
-            output.argmax(dim=1, keepdim=True)
-            for output in batch_outputs.values()
-        ]
-        return self.pack_result(batch_outputs[0], data_samples)
-
-
 @__BACKEND_MODEL.register_module('vacc_seg')
 class VACCModel(End2EndModel):
     """SDK inference class, converts VACC output to mmseg format."""
diff --git a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py
index f83fb9d0e..312c0fbfe 100644
--- a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py
+++ b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmdeploy.core import FUNCTION_REWRITER, mark
-from mmdeploy.utils.constants import Backend
+from mmdeploy.utils import get_codebase_config
 
 
 @FUNCTION_REWRITER.register_rewriter(
@@ -26,6 +26,10 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
     x = self.extract_feat(inputs)
     seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
 
+    ctx = FUNCTION_REWRITER.get_context()
+    if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
+        return seg_logit
+
     # mark seg_head
     @mark('decode_head', outputs=['output'])
     def __mark_seg_logit(seg_logit):
@@ -35,28 +39,3 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
 
     seg_pred = seg_logit.argmax(dim=1, keepdim=True)
     return seg_pred
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    func_name='mmseg.models.segmentors.EncoderDecoder.predict',
-    backend=Backend.RKNN.value)
-def encoder_decoder__predict__rknn(self, inputs, data_samples, **kwargs):
-    """Rewrite `predict` for RKNN backend.
-
-    Early return to avoid argmax operator.
-
-    Args:
-        ctx (ContextCaller): The context with additional information.
-        self: The instance of the original class.
-        inputs (Tensor): Inputs with shape (N, C, H, W).
-        data_samples (SampleList): The seg data samples.
-
-    Returns:
-        torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
-    """
-    batch_img_metas = []
-    for data_sample in data_samples:
-        batch_img_metas.append(data_sample.metainfo)
-    x = self.extract_feat(inputs)
-    seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
-    return seg_logit