Add with argmax in config for mmseg (#2038)
* add with_argmax for model conversion in mmseg * resolve lint
pull/2033/head
parent
e9c0092b87
commit
5ebd10b029
|
@ -2,7 +2,7 @@ _base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
|
|||
|
||||
onnx_config = dict(input_shape=[320, 320])
|
||||
|
||||
codebase_config = dict(model_type='rknn')
|
||||
codebase_config = dict(with_argmax=False)
|
||||
|
||||
backend_config = dict(
|
||||
input_size_list=[[3, 320, 320]],
|
||||
|
|
|
@ -2,6 +2,6 @@ _base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
|
|||
|
||||
onnx_config = dict(input_shape=[320, 320])
|
||||
|
||||
codebase_config = dict(model_type='rknn')
|
||||
codebase_config = dict(with_argmax=False)
|
||||
|
||||
backend_config = dict(input_size_list=[[3, 320, 320]])
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
_base_ = ['../_base_/onnx_config.py']
|
||||
codebase_config = dict(type='mmseg', task='Segmentation')
|
||||
codebase_config = dict(type='mmseg', task='Segmentation', with_argmax=True)
|
||||
|
|
|
@ -231,3 +231,5 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
|
|||
- <i id="static_shape">PSPNet, Fast-SCNN</i> only support static shape, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) is not supported by most inference backends.
|
||||
|
||||
- For models that only support static shape, you should use a static-shape deployment config file, such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.
|
||||
|
||||
- For users who prefer that the deployed model generate a probability feature map rather than the argmax result, put `codebase_config = dict(with_argmax=False)` in the deploy config.
|
||||
|
|
|
@ -235,3 +235,5 @@ cv2.imwrite('output_segmentation.png', img)
|
|||
- <i id="static_shape">PSPNet,Fast-SCNN</i> 仅支持静态输入,因为多数推理框架的 [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/0c87f7a0c9099844eff8e90fa3db5b0d0ca02fee/mmseg/models/decode_heads/psp_head.py#L38) 不支持动态输入。
|
||||
|
||||
- 对于仅支持静态形状的模型,应使用静态形状的部署配置文件,例如 `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`。
|
||||
|
||||
- 对于喜欢部署模型生成概率特征图的用户,将 `codebase_config = dict(with_argmax=False)` 放在部署配置中就足够了。
|
||||
|
|
|
@ -13,7 +13,8 @@ from mmengine.model import BaseDataPreprocessor
|
|||
from mmengine.registry import Registry
|
||||
|
||||
from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
|
||||
from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger
|
||||
from mmdeploy.utils import (Codebase, Task, get_codebase_config,
|
||||
get_input_shape, get_root_logger)
|
||||
|
||||
|
||||
def process_model_config(model_cfg: mmengine.Config,
|
||||
|
@ -303,6 +304,9 @@ class Segmentation(BaseTask):
|
|||
if isinstance(params, list):
|
||||
params = params[-1]
|
||||
postprocess = dict(params=params, type='ResizeMask')
|
||||
with_argmax = get_codebase_config(self.deploy_cfg).get(
|
||||
'with_argmax', True)
|
||||
postprocess['with_argmax'] = with_argmax
|
||||
return postprocess
|
||||
|
||||
def get_model_name(self, *args, **kwargs) -> str:
|
||||
|
|
|
@ -105,6 +105,9 @@ class End2EndModel(BaseBackendModel):
|
|||
for seg_pred, data_sample in zip(batch_outputs, data_samples):
|
||||
# resize seg_pred to original image shape
|
||||
metainfo = data_sample.metainfo
|
||||
if get_codebase_config(self.deploy_cfg).get('with_argmax',
|
||||
True) is False:
|
||||
seg_pred = seg_pred.argmax(dim=0, keepdim=True)
|
||||
if metainfo['ori_shape'] != metainfo['img_shape']:
|
||||
from mmseg.models.utils import resize
|
||||
ori_type = seg_pred.dtype
|
||||
|
@ -119,39 +122,6 @@ class End2EndModel(BaseBackendModel):
|
|||
return predictions
|
||||
|
||||
|
||||
@__BACKEND_MODEL.register_module('rknn')
|
||||
class RKNNModel(End2EndModel):
|
||||
"""SDK inference class, converts RKNN output to mmseg format."""
|
||||
|
||||
def forward(self,
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[List[BaseDataElement]] = None,
|
||||
mode: str = 'predict'):
|
||||
"""Run forward inference.
|
||||
|
||||
Args:
|
||||
inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
data_samples (list[:obj:`SegDataSample`]): The seg data
|
||||
samples. It usually includes information such as
|
||||
`metainfo` and `gt_sem_seg`. Default to None.
|
||||
|
||||
Returns:
|
||||
list: A list contains predictions.
|
||||
"""
|
||||
assert mode == 'predict', \
|
||||
'Backend model only support mode==predict,' f' but get {mode}'
|
||||
if inputs.device != torch.device(self.device):
|
||||
get_root_logger().warning(f'expect input device {self.device}'
|
||||
f' but get {inputs.device}.')
|
||||
inputs = inputs.to(self.device)
|
||||
batch_outputs = self.wrapper({self.input_name: inputs})
|
||||
batch_outputs = [
|
||||
output.argmax(dim=1, keepdim=True)
|
||||
for output in batch_outputs.values()
|
||||
]
|
||||
return self.pack_result(batch_outputs[0], data_samples)
|
||||
|
||||
|
||||
@__BACKEND_MODEL.register_module('vacc_seg')
|
||||
class VACCModel(End2EndModel):
|
||||
"""SDK inference class, converts VACC output to mmseg format."""
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdeploy.core import FUNCTION_REWRITER, mark
|
||||
from mmdeploy.utils.constants import Backend
|
||||
from mmdeploy.utils import get_codebase_config
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
|
@ -26,6 +26,10 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
|
|||
x = self.extract_feat(inputs)
|
||||
seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
|
||||
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
|
||||
return seg_logit
|
||||
|
||||
# mark seg_head
|
||||
@mark('decode_head', outputs=['output'])
|
||||
def __mark_seg_logit(seg_logit):
|
||||
|
@ -35,28 +39,3 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
|
|||
|
||||
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
|
||||
return seg_pred
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmseg.models.segmentors.EncoderDecoder.predict',
|
||||
backend=Backend.RKNN.value)
|
||||
def encoder_decoder__predict__rknn(self, inputs, data_samples, **kwargs):
|
||||
"""Rewrite `predict` for RKNN backend.
|
||||
|
||||
Early return to avoid argmax operator.
|
||||
|
||||
Args:
|
||||
ctx (ContextCaller): The context with additional information.
|
||||
self: The instance of the original class.
|
||||
inputs (Tensor): Inputs with shape (N, C, H, W).
|
||||
data_samples (SampleList): The seg data samples.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output segmentation map of shape [N, 1, H, W].
|
||||
"""
|
||||
batch_img_metas = []
|
||||
for data_sample in data_samples:
|
||||
batch_img_metas.append(data_sample.metainfo)
|
||||
x = self.extract_feat(inputs)
|
||||
seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
|
||||
return seg_logit
|
||||
|
|
Loading…
Reference in New Issue