Add with argmax in config for mmseg (#2038)

* add with_argmax for model conversion in mmseg

* resolve lint
AllentDan 2023-04-27 16:01:45 +08:00 committed by GitHub
parent e9c0092b87
commit 5ebd10b029
8 changed files with 20 additions and 63 deletions

View File

@@ -2,7 +2,7 @@ _base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']

 onnx_config = dict(input_shape=[320, 320])

-codebase_config = dict(model_type='rknn')
+codebase_config = dict(with_argmax=False)

 backend_config = dict(
     input_size_list=[[3, 320, 320]],

View File

@@ -2,6 +2,6 @@ _base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']

 onnx_config = dict(input_shape=[320, 320])

-codebase_config = dict(model_type='rknn')
+codebase_config = dict(with_argmax=False)

 backend_config = dict(input_size_list=[[3, 320, 320]])

View File

@@ -1,2 +1,2 @@
 _base_ = ['../_base_/onnx_config.py']
-codebase_config = dict(type='mmseg', task='Segmentation')
+codebase_config = dict(type='mmseg', task='Segmentation', with_argmax=True)

View File

@@ -231,3 +231,5 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Interface)
 - <i id="static_shape">PSPNet, Fast-SCNN</i> only support static shape, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) is not supported by most inference backends.
 - For models that only support static shape, use a static-shape deployment config such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.
+- For users who prefer the deployed model to output a probability feature map instead of a label map, set `codebase_config = dict(with_argmax=False)` in the deploy config, as in the sketch below.
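
For illustration, a minimal sketch of such a deploy config. The override pattern mirrors the rknn configs changed above; the base file name is taken from the bullet list, and the exact combination is illustrative rather than a file added by this commit:

```python
# illustrative deploy config: keep the probability feature map in the output
_base_ = ['./segmentation_tensorrt_static-1024x2048.py']

# overrides the with_argmax=True default set in segmentation_static.py,
# so the deployed model returns per-class scores instead of a label map
codebase_config = dict(with_argmax=False)
```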

View File

@@ -235,3 +235,5 @@ cv2.imwrite('output_segmentation.png', img)
 - <i id="static_shape">PSPNet, Fast-SCNN</i> only support static input, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) does not support dynamic input in most inference frameworks.
 - For models that only support static shape, use a static-shape deployment config such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.
+- For users who prefer the deployed model to output a probability feature map, putting `codebase_config = dict(with_argmax=False)` in the deploy config is enough.

View File

@@ -13,7 +13,8 @@ from mmengine.model import BaseDataPreprocessor
 from mmengine.registry import Registry

 from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
-from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger
+from mmdeploy.utils import (Codebase, Task, get_codebase_config,
+                            get_input_shape, get_root_logger)


 def process_model_config(model_cfg: mmengine.Config,
@@ -303,6 +304,9 @@ class Segmentation(BaseTask):
         if isinstance(params, list):
             params = params[-1]
         postprocess = dict(params=params, type='ResizeMask')
+        with_argmax = get_codebase_config(self.deploy_cfg).get(
+            'with_argmax', True)
+        postprocess['with_argmax'] = with_argmax
         return postprocess

     def get_model_name(self, *args, **kwargs) -> str:
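
For context, a rough sketch of how the converter reads this flag, assuming `mmengine.Config` and `mmdeploy.utils.get_codebase_config` behave as used in the hunk above (the toy config is illustrative):

```python
from mmengine import Config

from mmdeploy.utils import get_codebase_config

# toy deploy config mirroring configs/mmseg/segmentation_static.py
deploy_cfg = Config(
    dict(
        codebase_config=dict(
            type='mmseg', task='Segmentation', with_argmax=True)))

# the same lookup Segmentation.get_postprocess performs; defaults to True
with_argmax = get_codebase_config(deploy_cfg).get('with_argmax', True)
print(with_argmax)  # True -> forwarded to the SDK as postprocess['with_argmax']
```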

View File

@@ -105,6 +105,9 @@ class End2EndModel(BaseBackendModel):
         for seg_pred, data_sample in zip(batch_outputs, data_samples):
             # resize seg_pred to original image shape
             metainfo = data_sample.metainfo
+            if get_codebase_config(self.deploy_cfg).get('with_argmax',
+                                                        True) is False:
+                seg_pred = seg_pred.argmax(dim=0, keepdim=True)
             if metainfo['ori_shape'] != metainfo['img_shape']:
                 from mmseg.models.utils import resize
                 ori_type = seg_pred.dtype
@@ -119,39 +122,6 @@ class End2EndModel(BaseBackendModel):
         return predictions


-@__BACKEND_MODEL.register_module('rknn')
-class RKNNModel(End2EndModel):
-    """SDK inference class, converts RKNN output to mmseg format."""
-
-    def forward(self,
-                inputs: torch.Tensor,
-                data_samples: Optional[List[BaseDataElement]] = None,
-                mode: str = 'predict'):
-        """Run forward inference.
-
-        Args:
-            inputs (Tensor): Inputs with shape (N, C, H, W).
-            data_samples (list[:obj:`SegDataSample`]): The seg data
-                samples. It usually includes information such as
-                `metainfo` and `gt_sem_seg`. Default to None.
-
-        Returns:
-            list: A list contains predictions.
-        """
-        assert mode == 'predict', \
-            'Backend model only support mode==predict,' f' but get {mode}'
-        if inputs.device != torch.device(self.device):
-            get_root_logger().warning(f'expect input device {self.device}'
-                                      f' but get {inputs.device}.')
-        inputs = inputs.to(self.device)
-        batch_outputs = self.wrapper({self.input_name: inputs})
-        batch_outputs = [
-            output.argmax(dim=1, keepdim=True)
-            for output in batch_outputs.values()
-        ]
-        return self.pack_result(batch_outputs[0], data_samples)
-
-
 @__BACKEND_MODEL.register_module('vacc_seg')
 class VACCModel(End2EndModel):
     """SDK inference class, converts VACC output to mmseg format."""

View File

@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmdeploy.core import FUNCTION_REWRITER, mark
-from mmdeploy.utils.constants import Backend
+from mmdeploy.utils import get_codebase_config


 @FUNCTION_REWRITER.register_rewriter(
@@ -26,6 +26,10 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
     x = self.extract_feat(inputs)
     seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)

+    ctx = FUNCTION_REWRITER.get_context()
+    if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
+        return seg_logit
+
     # mark seg_head
     @mark('decode_head', outputs=['output'])
     def __mark_seg_logit(seg_logit):
@@ -35,28 +39,3 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
     seg_pred = seg_logit.argmax(dim=1, keepdim=True)
     return seg_pred
-
-
-@FUNCTION_REWRITER.register_rewriter(
-    func_name='mmseg.models.segmentors.EncoderDecoder.predict',
-    backend=Backend.RKNN.value)
-def encoder_decoder__predict__rknn(self, inputs, data_samples, **kwargs):
-    """Rewrite `predict` for RKNN backend.
-
-    Early return to avoid argmax operator.
-
-    Args:
-        ctx (ContextCaller): The context with additional information.
-        self: The instance of the original class.
-        inputs (Tensor): Inputs with shape (N, C, H, W).
-        data_samples (SampleList): The seg data samples.
-
-    Returns:
-        torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
-    """
-    batch_img_metas = []
-    for data_sample in data_samples:
-        batch_img_metas.append(data_sample.metainfo)
-    x = self.extract_feat(inputs)
-    seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
-    return seg_logit
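
As a sanity check, a self-contained sketch (plain torch, not the rewriter itself) of the branch the exported graph now takes: with `with_argmax=False` the model ends at the logits, otherwise it appends the argmax:

```python
import torch


def predict_tail(seg_logit: torch.Tensor,
                 with_argmax: bool = True) -> torch.Tensor:
    """Toy replica of the rewriter's final step.

    seg_logit: decode-head output of shape (N, C, H, W).
    Returns the raw logits when with_argmax is False, otherwise an
    (N, 1, H, W) label map, matching the hunk above.
    """
    if with_argmax is False:
        return seg_logit
    return seg_logit.argmax(dim=1, keepdim=True)


seg_logit = torch.randn(2, 19, 64, 64)
assert predict_tail(seg_logit, with_argmax=False).shape == (2, 19, 64, 64)
assert predict_tail(seg_logit, with_argmax=True).shape == (2, 1, 64, 64)
```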