[Feature] Support probability output for segmentation (#1379)

* add do_argmax flag

* sdk support

* update doc

* replace do_argmax with with_argmax

* add todo
AllentDan 2022-12-26 15:48:07 +08:00 committed by GitHub
parent d113a5f1c7
commit 85b7b967ee
8 changed files with 30 additions and 4 deletions

View File

@@ -2,6 +2,6 @@ _base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
 onnx_config = dict(input_shape=[320, 320])
-codebase_config = dict(model_type='rknn')
+codebase_config = dict(model_type='rknn', with_argmax=False)
 backend_config = dict(input_size_list=[[3, 320, 320]])

View File

@@ -1,2 +1,2 @@
 _base_ = ['../_base_/onnx_config.py']
-codebase_config = dict(type='mmseg', task='Segmentation')
+codebase_config = dict(type='mmseg', task='Segmentation', with_argmax=True)
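
Taken together, the two config diffs above layer the new flag: the base mmseg codebase config defaults `with_argmax` to `True`, while the RKNN config overrides it to `False`. A minimal sketch of the resulting merge, using plain dict merging to stand in for mmcv `_base_` inheritance (illustrative only):

```python
# Illustrative only: plain dict merging standing in for mmcv _base_ inheritance.
base_codebase_config = dict(type='mmseg', task='Segmentation', with_argmax=True)
rknn_codebase_config = dict(model_type='rknn', with_argmax=False)

merged = {**base_codebase_config, **rknn_codebase_config}
print(merged)
# {'type': 'mmseg', 'task': 'Segmentation', 'model_type': 'rknn', 'with_argmax': False}
```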

View File

@@ -18,6 +18,7 @@ class ResizeMask : public MMSegmentation {
   explicit ResizeMask(const Value &cfg) : MMSegmentation(cfg) {
     try {
       classes_ = cfg["params"]["num_classes"].get<int>();
+      with_argmax_ = cfg["params"].value("with_argmax", true);
       little_endian_ = IsLittleEndian();
     } catch (const std::exception &e) {
       MMDEPLOY_ERROR("no ['params']['num_classes'] is specified in cfg: {}", cfg);
@@ -31,10 +32,19 @@ class ResizeMask : public MMSegmentation {
     auto mask = inference_result["output"].get<Tensor>();
     MMDEPLOY_DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(),
                    mask.shape(), mask.data_type());
-    if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) {
+    if (!(mask.shape().size() == 4 && mask.shape(0) == 1)) {
       MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}", mask.shape());
       return Status(eNotSupported);
     }
+    if ((mask.shape(1) != 1) && with_argmax_) {
+      MMDEPLOY_ERROR("probability feat map with shape: {} requires `with_argmax_=false`",
+                     mask.shape());
+      return Status(eNotSupported);
+    }
+    if (!with_argmax_) {
+      MMDEPLOY_ERROR("TODO: SDK will support probability featmap soon.");
+      return Status(eNotSupported);
+    }
     auto height = (int)mask.shape(2);
     auto width = (int)mask.shape(3);
@@ -85,6 +95,7 @@ class ResizeMask : public MMSegmentation {
  protected:
   int classes_{};
+  bool with_argmax_{true};
   bool little_endian_;
 };
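
For readers who do not follow the C++, here is a rough Python mirror (not part of the SDK) of the shape and flag checks `ResizeMask` now performs on the backend output tensor:

```python
import numpy as np


def check_mask(mask: np.ndarray, with_argmax: bool = True) -> None:
    """Rough Python mirror of the ResizeMask guards above (illustrative only)."""
    # The output must be 4-D with batch size 1, i.e. (1, C, H, W).
    if not (mask.ndim == 4 and mask.shape[0] == 1):
        raise ValueError(f'unsupported `output` tensor, shape: {mask.shape}')
    # A multi-channel map is a probability feature map; it is only valid when
    # the model was exported with with_argmax=False.
    if mask.shape[1] != 1 and with_argmax:
        raise ValueError('probability feat map requires with_argmax=False')
    # Probability feature maps are not consumed by the SDK yet (TODO in this PR).
    if not with_argmax:
        raise NotImplementedError('SDK will support probability featmap soon')
```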

View File

@@ -51,3 +51,5 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmsegmentation/bl
 - <i id="static_shape">PSPNet, Fast-SCNN</i> only support static shape, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/97f9670c5a4a2a3b4cfb411bcc26db16b23745f7/mmseg/models/decode_heads/psp_head.py#L38) is not supported in most of backends dynamically.
 - For models only supporting static shape, you should use the deployment config file of static shape such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.
+- For users who prefer the deployed model to generate a probability feature map, set `codebase_config = dict(with_argmax=False)` in the deploy config, as in the sketch below.
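
A hypothetical deploy config using the new option might look like the following sketch; the `_base_` entries are assumptions following the repository's config layout, and only the `with_argmax=False` override comes from this PR:

```python
# Hypothetical deploy config sketch: export a segmentation model that returns
# softmax probabilities instead of argmax-ed label maps.
_base_ = ['./segmentation_static.py', '../_base_/backends/onnxruntime.py']  # assumed paths
codebase_config = dict(with_argmax=False)
```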

View File

@@ -51,3 +51,5 @@ mmseg is an open-source object segmentation toolbox based on PyTorch, and is also part of the [OpenMMLab
 - <i id="static_shape">PSPNet, Fast-SCNN</i> only support static input, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/97f9670c5a4a2a3b4cfb411bcc26db16b23745f7/mmseg/models/decode_heads/psp_head.py#L38) is not supported dynamically by most inference frameworks.
 - For models that only support static shapes, use a static-shape deployment config such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.
+- For users who prefer the deployed model to generate a probability feature map, putting `codebase_config = dict(with_argmax=False)` in the deploy config is sufficient.

View File

@@ -8,7 +8,7 @@ from mmcv.parallel import DataContainer
 from torch.utils.data import Dataset
 from mmdeploy.codebase.base import BaseTask
-from mmdeploy.utils import Task, get_input_shape
+from mmdeploy.utils import Task, get_codebase_config, get_input_shape
 from .mmsegmentation import MMSEG_TASK
@@ -286,6 +286,9 @@ class Segmentation(BaseTask):
         postprocess = self.model_cfg.model.decode_head
         if isinstance(postprocess, list):
             postprocess = postprocess[-1]
+        with_argmax = get_codebase_config(self.deploy_cfg).get(
+            'with_argmax', True)
+        postprocess['with_argmax'] = with_argmax
         return postprocess
     def get_model_name(self) -> str:
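
The dict returned by `get_postprocess()` becomes the SDK's postprocess `params`, which the `ResizeMask` constructor above reads. A hypothetical example of its shape after this change (the field values are placeholders, not taken from the PR):

```python
# Hypothetical postprocess params handed to the SDK (values are placeholders).
postprocess = dict(
    type='FCNHead',      # the (last) decode_head from the model config
    num_classes=19,      # read by ResizeMask as params.num_classes
    with_argmax=True,    # newly copied from codebase_config; defaults to True
)
```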

View File

@@ -109,6 +109,11 @@ class End2EndModel(BaseBackendModel):
         """
         outputs = self.wrapper({self.input_name: imgs})
         outputs = self.wrapper.output_to_list(outputs)
+        if get_codebase_config(self.deploy_cfg).get('with_argmax',
+                                                    True) is False:
+            outputs = [
+                output.argmax(dim=1, keepdim=True) for output in outputs
+            ]
         outputs = [out.detach().cpu().numpy() for out in outputs]
         return outputs
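
When the exported model keeps probability output (`with_argmax=False`), the backend wrapper restores a label map so downstream evaluation and visualization stay unchanged. A small torch sketch of that post-hoc argmax (shapes are examples):

```python
import torch

# Example backend output: (N, C, H, W) per-class probabilities.
probs = torch.softmax(torch.randn(1, 19, 64, 64), dim=1)
# Post-hoc argmax recovers a (N, 1, H, W) label map, matching with_argmax=True output.
labels = probs.argmax(dim=1, keepdim=True)
print(labels.shape)  # torch.Size([1, 1, 64, 64])
```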

View File

@@ -2,6 +2,7 @@
 import torch.nn.functional as F
 from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.utils.config_utils import get_codebase_config
 from mmdeploy.utils.constants import Backend
@@ -25,6 +26,8 @@ def encoder_decoder__simple_test(ctx, self, img, img_meta, **kwargs):
     """
     seg_logit = self.encode_decode(img, img_meta)
     seg_logit = F.softmax(seg_logit, dim=1)
+    if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
+        return seg_logit
     seg_pred = seg_logit.argmax(dim=1, keepdim=True)
     return seg_pred
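
Finally, a minimal sketch of the flag lookup the rewriter relies on: `with_argmax` is read from `codebase_config` in the deploy config and defaults to `True`, so existing deploy configs keep exporting argmax-ed label maps. The in-memory `mmcv.Config` here is illustrative; real deploy configs are loaded from files.

```python
import mmcv

from mmdeploy.utils import get_codebase_config

# Illustrative in-memory deploy config standing in for a config file.
deploy_cfg = mmcv.Config(
    dict(codebase_config=dict(type='mmseg', task='Segmentation', with_argmax=False)))
print(get_codebase_config(deploy_cfg).get('with_argmax', True))  # False
```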