Merge remote-tracking branch 'upstream/main' into fix_ut

commit 9e7a187f5c
Author: RunningLeon
Date:   2023-06-25 15:11:39 +08:00

10 changed files with 82 additions and 30 deletions

View File

@@ -18,7 +18,7 @@ DEFINE_int32(flip, 0, "Set to 1 for flipping the input horizontally");
 DEFINE_int32(show, 1, "Delay passed to `cv::waitKey` when using `cv::imshow`; -1: disable");
 DEFINE_string(skeleton, "coco",
-              R"(Path to skeleton data or name of predefined skeletons: "coco", "coco-wholebody")");
+              R"(Path to skeleton data or name of predefined skeletons: "coco", "coco-wholebody", "coco-wholebody-hand")");
 DEFINE_string(background, "default",
               R"(Output background, "default": original image, "black": black background)");

View File

@@ -73,6 +73,25 @@ const Skeleton& gSkeletonCocoWholeBody() {
   return inst;
 }
 
+const Skeleton& gSkeletonCocoWholeBodyHand() {
+  static const Skeleton inst{
+      {
+          {0, 1},  {1, 2},   {2, 3},   {3, 4},
+          {0, 5},  {5, 6},   {6, 7},   {7, 8},
+          {0, 9},  {9, 10},  {10, 11}, {11, 12},
+          {0, 13}, {13, 14}, {14, 15}, {15, 16},
+          {0, 17}, {17, 18}, {18, 19}, {19, 20},
+      },
+      {
+          {255, 255, 255}, {255, 128, 0}, {255, 153, 255},
+          {102, 178, 255}, {255, 51, 51}, {0, 255, 0},
+      },
+      {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5},
+      {0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5},
+  };
+  return inst;
+}
+
 // n_links
 // u0, v0, u1, v1, ..., un-1, vn-1
 // n_palette
@@ -86,6 +105,8 @@ inline Skeleton Skeleton::get(const std::string& path) {
     return gSkeletonCoco();
   } else if (path == "coco-wholebody") {
     return gSkeletonCocoWholeBody();
+  } else if (path == "coco-wholebody-hand") {
+    return gSkeletonCocoWholeBodyHand();
   }
   std::ifstream ifs(path);
   if (!ifs.is_open()) {
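
Note: the new table follows the standard 21-keypoint hand layout: four links per finger fanning out from the wrist (keypoint 0), a six-color palette, one palette index per link, and one per point. A minimal Python sketch of how such a table could be rendered — the cv2 drawing code and the draw_hand helper are illustrative assumptions, not the demo's actual visualizer:

import cv2
import numpy as np

# Tables copied from the diff above: 21 hand keypoints, 20 links.
LINKS = [(0, 1), (1, 2), (2, 3), (3, 4),        # thumb
         (0, 5), (5, 6), (6, 7), (7, 8),        # index
         (0, 9), (9, 10), (10, 11), (11, 12),   # middle
         (0, 13), (13, 14), (14, 15), (15, 16), # ring
         (0, 17), (17, 18), (18, 19), (19, 20)] # pinky
PALETTE = [(255, 255, 255), (255, 128, 0), (255, 153, 255),
           (102, 178, 255), (255, 51, 51), (0, 255, 0)]
LINK_COLOR = [1] * 4 + [2] * 4 + [3] * 4 + [4] * 4 + [5] * 4   # per link
POINT_COLOR = [0] + [1] * 4 + [2] * 4 + [3] * 4 + [4] * 4 + [5] * 4  # per point

def draw_hand(img: np.ndarray, kpts: np.ndarray) -> np.ndarray:
    """Draw one hand; `kpts` is a (21, 2) array of pixel coordinates."""
    for (u, v), c in zip(LINKS, LINK_COLOR):
        cv2.line(img, tuple(map(int, kpts[u])), tuple(map(int, kpts[v])),
                 PALETTE[c], 1)
    for p, c in zip(kpts, POINT_COLOR):
        cv2.circle(img, tuple(map(int, p)), 2, PALETTE[c], -1)
    return img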

View File

@@ -1,4 +1,4 @@
-# onnxruntime 支持情况
+# onnxruntime Support
 
 ## Introduction of ONNX Runtime

View File

@@ -1,8 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 from mmseg.structures import SegDataSample
 
 from mmdeploy.core import FUNCTION_REWRITER, mark
-from mmdeploy.utils import is_dynamic_shape
+from mmdeploy.utils import get_codebase_config, is_dynamic_shape
 
 
 @FUNCTION_REWRITER.register_rewriter(
@@ -47,4 +48,22 @@ def base_segmentor__forward(self,
     for data_sample in data_samples:
         data_sample.set_field(
             name='img_shape', value=img_shape, field_type='metainfo')
-    return self.predict(inputs, data_samples)
+    seg_logit = self.predict(inputs, data_samples)
+
+    # mark seg_head
+    @mark('decode_head', outputs=['output'])
+    def __mark_seg_logit(seg_logit):
+        return seg_logit
+
+    ctx = FUNCTION_REWRITER.get_context()
+    with_argmax = get_codebase_config(ctx.cfg).get('with_argmax', True)
+    # deal with out_channels=1 with two classes
+    if seg_logit.shape[1] == 1:
+        seg_logit = seg_logit.sigmoid()
+        seg_pred = seg_logit > self.decode_head.threshold
+        seg_pred = seg_pred.to(torch.int64)
+    else:
+        seg_pred = __mark_seg_logit(seg_logit)
+        if with_argmax:
+            seg_pred = seg_pred.argmax(dim=1, keepdim=True)
+    return seg_pred
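
Note: with this rewrite the exported output format is driven by the deploy config: `with_argmax=True` (the default) yields an [N, 1, H, W] class-index map, `with_argmax=False` keeps the raw [N, C, H, W] logits, and single-channel heads take the sigmoid-plus-threshold path regardless of the flag. A minimal sketch of a deploy config toggling it — abbreviated, with the onnx_config fields (as in the test utility further down) omitted:

import mmengine

# Keep raw logits in the exported model instead of argmax-ed class maps.
deploy_cfg = mmengine.Config(
    dict(
        backend_config=dict(type='onnxruntime'),
        codebase_config=dict(
            type='mmseg', task='Segmentation', with_argmax=False)))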

View File

@@ -17,7 +17,7 @@ def cascade_encoder_decoder__predict(self, inputs, data_samples, **kwargs):
         data_samples (SampleList): The seg data samples.
 
     Returns:
-        torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
+        torch.Tensor: Output segmentation logits of shape [N, C, H, W].
     """
     batch_img_metas = []
     for data_sample in data_samples:
@@ -28,5 +28,4 @@ def cascade_encoder_decoder__predict(self, inputs, data_samples, **kwargs):
         out = self.decode_head[i].forward(x, out)
     seg_logit = self.decode_head[-1].predict(x, out, batch_img_metas,
                                              self.test_cfg)
-    seg_pred = seg_logit.argmax(dim=1, keepdim=True)
-    return seg_pred
+    return seg_logit

View File

@@ -1,6 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmdeploy.core import FUNCTION_REWRITER, mark
-from mmdeploy.utils import get_codebase_config
+from mmdeploy.core import FUNCTION_REWRITER
 
 
 @FUNCTION_REWRITER.register_rewriter(
@@ -18,24 +17,11 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
         data_samples (SampleList): The seg data samples.
 
     Returns:
-        torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
+        torch.Tensor: Output segmentation logits of shape [N, C, H, W].
     """
     batch_img_metas = []
     for data_sample in data_samples:
         batch_img_metas.append(data_sample.metainfo)
     x = self.extract_feat(inputs)
     seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
-    ctx = FUNCTION_REWRITER.get_context()
-    if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
-        return seg_logit
-
-    # mark seg_head
-    @mark('decode_head', outputs=['output'])
-    def __mark_seg_logit(seg_logit):
-        return seg_logit
-
-    seg_logit = __mark_seg_logit(seg_logit)
-    seg_pred = seg_logit.argmax(dim=1, keepdim=True)
-    return seg_pred
+    return seg_logit
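
Note: both predict rewriters now return raw logits; the argmax and the with_argmax check moved up into base_segmentor__forward above, so any caller consuming predict directly has to reduce the logits itself. For example:

import torch

seg_logit = torch.randn(1, 19, 64, 64)  # dummy [N, C, H, W] logits
seg_pred = seg_logit.argmax(dim=1, keepdim=True)  # [N, 1, H, W] class map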

View File

@@ -345,7 +345,11 @@ def _multiclass_nms_single(boxes: Tensor,
     labels = labels[:, topk_inds, ...]
 
     if output_index:
-        bbox_index = pre_topk_inds[None, topk_inds]
+        bbox_index = box_inds.unsqueeze(0)
+        if pre_top_k > 0:
+            bbox_index = pre_topk_inds[None, box_inds]
+        if keep_top_k > 0:
+            bbox_index = bbox_index[:, topk_inds[:-1]]
         return dets, labels, bbox_index
     else:
         return dets, labels
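
Note: the fix composes the kept-box indices through both top-k passes instead of assuming the pre-top-k filter always ran. A self-contained sketch of the index bookkeeping — all tensor values here are made-up examples, only the variable names mirror the diff:

import torch

pre_top_k, keep_top_k = 4, 2
pre_topk_inds = torch.tensor([7, 2, 9, 4])  # pre-top-k: positions in the original box list
box_inds = torch.tensor([1, 3])             # boxes kept by NMS, indexed into the pre-top-k set
topk_inds = torch.tensor([0, 1, 2])         # final top-k selection; last slot is a pad to drop

bbox_index = box_inds.unsqueeze(0)          # shape [1, num_kept]
if pre_top_k > 0:
    # translate back to indices into the original, unfiltered boxes
    bbox_index = pre_topk_inds[None, box_inds]
if keep_top_k > 0:
    # keep only the final top-k entries (topk_inds[:-1] drops the pad)
    bbox_index = bbox_index[:, topk_inds[:-1]]
print(bbox_index)  # tensor([[2, 4]])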

View File

@@ -2,6 +2,7 @@
 import mmengine
 import pytest
 import torch
+from packaging import version
 
 from mmdeploy.codebase import import_codebase
 from mmdeploy.utils import Backend, Codebase, Task
@@ -36,14 +37,31 @@ def test_encoderdecoder_predict(backend):
         wrapped_model=wrapped_model,
         model_inputs=rewrite_inputs,
         deploy_cfg=deploy_cfg)
-    assert torch.allclose(model_outputs, rewrite_outputs[0].squeeze(0))
+    rewrite_outputs = segmentor.postprocess_result(rewrite_outputs[0],
+                                                   data_samples)
+    rewrite_outputs = rewrite_outputs[0].pred_sem_seg.data
+    assert torch.allclose(model_outputs, rewrite_outputs)
 
 
 @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
-def test_basesegmentor_forward(backend):
+@pytest.mark.parametrize('with_argmax,use_sigmoid', [(True, False),
+                                                     (False, True)])
+def test_basesegmentor_forward(backend: Backend, with_argmax: bool,
+                               use_sigmoid: bool):
     check_backend(backend)
+    config_path = 'tests/test_codebase/test_mmseg/data/model.py'
+    model_cfg = mmengine.Config.fromfile(config_path)
+    if use_sigmoid:
+        import mmseg
+        if version.parse(mmseg.__version__) <= version.parse('1.0.0'):
+            pytest.skip('ignore mmseg<=1.0.0')
+        model_cfg.model.decode_head.num_classes = 2
+        model_cfg.model.decode_head.out_channels = 1
+        model_cfg.model.decode_head.threshold = 0.3
     deploy_cfg = generate_mmseg_deploy_config(backend.value)
+    deploy_cfg.codebase_config.with_argmax = with_argmax
-    task_processor = generate_mmseg_task_processor(deploy_cfg=deploy_cfg)
+    task_processor = generate_mmseg_task_processor(
+        deploy_cfg=deploy_cfg, model_cfg=model_cfg)
     segmentor = task_processor.build_pytorch_model()
     size = 256
     inputs = torch.randn(1, 3, size, size)
@@ -58,7 +76,11 @@ def test_basesegmentor_forward(backend):
         wrapped_model=wrapped_model,
         model_inputs=rewrite_inputs,
         deploy_cfg=deploy_cfg)
-    assert torch.allclose(model_outputs, rewrite_outputs[0].squeeze(0))
+    rewrite_outputs = rewrite_outputs[0]
+    if rewrite_outputs.shape[1] != 1:
+        rewrite_outputs = rewrite_outputs.argmax(dim=1, keepdim=True)
+    rewrite_outputs = rewrite_outputs.squeeze(0).to(model_outputs)
+    assert torch.allclose(model_outputs, rewrite_outputs)
 
 
 @pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])

View File

@@ -24,7 +24,8 @@ def generate_mmseg_deploy_config(backend='onnxruntime'):
     deploy_cfg = mmengine.Config(
         dict(
             backend_config=dict(type=backend),
-            codebase_config=dict(type='mmseg', task='Segmentation'),
+            codebase_config=dict(
+                type='mmseg', task='Segmentation', with_argmax=False),
             onnx_config=dict(
                 type='onnx',
                 export_params=True,
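
Note: the test utility now defaults to with_argmax=False so wrapped models emit logits, and individual tests flip the flag via deploy_cfg.codebase_config.with_argmax as in the test above. The rewriter side reads the flag the same way base_segmentor__forward does:

from mmdeploy.utils import get_codebase_config

# `deploy_cfg` is a config built as above; a missing key defaults to True.
with_argmax = get_codebase_config(deploy_cfg).get('with_argmax', True)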