[Feature] Support probability output for segmentation (#1379)
* add do_argmax flag * sdk support * update doc * replace do_argmax with with_argmax * add todo
pull/1577/head
parent
d113a5f1c7
commit
85b7b967ee
|
@ -2,6 +2,6 @@ _base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']
|
|||
|
||||
onnx_config = dict(input_shape=[320, 320])
|
||||
|
||||
codebase_config = dict(model_type='rknn')
|
||||
codebase_config = dict(model_type='rknn', with_argmax=False)
|
||||
|
||||
backend_config = dict(input_size_list=[[3, 320, 320]])
|
||||
|
|
|
@ -1,2 +1,2 @@
|
|||
_base_ = ['../_base_/onnx_config.py']
|
||||
codebase_config = dict(type='mmseg', task='Segmentation')
|
||||
codebase_config = dict(type='mmseg', task='Segmentation', with_argmax=True)
|
||||
|
|
|
@ -18,6 +18,7 @@ class ResizeMask : public MMSegmentation {
|
|||
explicit ResizeMask(const Value &cfg) : MMSegmentation(cfg) {
|
||||
try {
|
||||
classes_ = cfg["params"]["num_classes"].get<int>();
|
||||
with_argmax_ = cfg["params"].value("with_argmax", true);
|
||||
little_endian_ = IsLittleEndian();
|
||||
} catch (const std::exception &e) {
|
||||
MMDEPLOY_ERROR("no ['params']['num_classes'] is specified in cfg: {}", cfg);
|
||||
|
@ -31,10 +32,19 @@ class ResizeMask : public MMSegmentation {
|
|||
auto mask = inference_result["output"].get<Tensor>();
|
||||
MMDEPLOY_DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(),
|
||||
mask.shape(), mask.data_type());
|
||||
if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) {
|
||||
if (!(mask.shape().size() == 4 && mask.shape(0) == 1)) {
|
||||
MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}", mask.shape());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
if ((mask.shape(1) != 1) && with_argmax_) {
|
||||
MMDEPLOY_ERROR("probability feat map with shape: {} requires `with_argmax_=false`",
|
||||
mask.shape());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
if (!with_argmax_) {
|
||||
MMDEPLOY_ERROR("TODO: SDK will support probability featmap soon.");
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
auto height = (int)mask.shape(2);
|
||||
auto width = (int)mask.shape(3);
|
||||
|
@ -85,6 +95,7 @@ class ResizeMask : public MMSegmentation {
|
|||
|
||||
protected:
|
||||
int classes_{};
|
||||
bool with_argmax_{true};
|
||||
bool little_endian_;
|
||||
};
|
||||
|
||||
|
|
|
@ -51,3 +51,5 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmsegmentation/bl
|
|||
- <i id="static_shape">PSPNet, Fast-SCNN</i> only support static shape, because [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/97f9670c5a4a2a3b4cfb411bcc26db16b23745f7/mmseg/models/decode_heads/psp_head.py#L38) is not supported in most of backends dynamically.
|
||||
|
||||
- For models only supporting static shape, you should use the deployment config file of static shape such as `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`.
|
||||
|
||||
- For users who prefer that deployed models generate a probability feature map, put `codebase_config = dict(with_argmax=False)` in the deploy config.
|
||||
|
|
|
@ -51,3 +51,5 @@ mmseg 是一个基于 PyTorch 的开源对象分割工具箱,也是 [OpenMMLab
|
|||
- <i id="static_shape">PSPNet,Fast-SCNN</i> 仅支持静态输入,因为多数推理框架的 [nn.AdaptiveAvgPool2d](https://github.com/open-mmlab/mmsegmentation/blob/97f9670c5a4a2a3b4cfb411bcc26db16b23745f7/mmseg/models/decode_heads/psp_head.py#L38) 不支持动态输入。
|
||||
|
||||
- 对于仅支持静态形状的模型,应使用静态形状的部署配置文件,例如 `configs/mmseg/segmentation_tensorrt_static-1024x2048.py`
|
||||
|
||||
- 对于希望部署后的模型输出概率特征图的用户,在部署配置中加入 `codebase_config = dict(with_argmax=False)` 即可。
|
||||
|
|
|
@ -8,7 +8,7 @@ from mmcv.parallel import DataContainer
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from mmdeploy.codebase.base import BaseTask
|
||||
from mmdeploy.utils import Task, get_input_shape
|
||||
from mmdeploy.utils import Task, get_codebase_config, get_input_shape
|
||||
from .mmsegmentation import MMSEG_TASK
|
||||
|
||||
|
||||
|
@ -286,6 +286,9 @@ class Segmentation(BaseTask):
|
|||
postprocess = self.model_cfg.model.decode_head
|
||||
if isinstance(postprocess, list):
|
||||
postprocess = postprocess[-1]
|
||||
with_argmax = get_codebase_config(self.deploy_cfg).get(
|
||||
'with_argmax', True)
|
||||
postprocess['with_argmax'] = with_argmax
|
||||
return postprocess
|
||||
|
||||
def get_model_name(self) -> str:
|
||||
|
|
|
@ -109,6 +109,11 @@ class End2EndModel(BaseBackendModel):
|
|||
"""
|
||||
outputs = self.wrapper({self.input_name: imgs})
|
||||
outputs = self.wrapper.output_to_list(outputs)
|
||||
if get_codebase_config(self.deploy_cfg).get('with_argmax',
|
||||
True) is False:
|
||||
outputs = [
|
||||
output.argmax(dim=1, keepdim=True) for output in outputs
|
||||
]
|
||||
outputs = [out.detach().cpu().numpy() for out in outputs]
|
||||
return outputs
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils.config_utils import get_codebase_config
|
||||
from mmdeploy.utils.constants import Backend
|
||||
|
||||
|
||||
|
@ -25,6 +26,8 @@ def encoder_decoder__simple_test(ctx, self, img, img_meta, **kwargs):
|
|||
"""
|
||||
seg_logit = self.encode_decode(img, img_meta)
|
||||
seg_logit = F.softmax(seg_logit, dim=1)
|
||||
if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
|
||||
return seg_logit
|
||||
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
|
||||
return seg_pred
|
||||
|
||||
|
|
Loading…
Reference in New Issue