Merge remote-tracking branch 'upstream/main' into fix_ut

This commit is contained in:
RunningLeon 2023-06-25 15:11:39 +08:00
commit 9e7a187f5c
10 changed files with 82 additions and 30 deletions

View File

@ -18,7 +18,7 @@ DEFINE_int32(flip, 0, "Set to 1 for flipping the input horizontally");
DEFINE_int32(show, 1, "Delay passed to `cv::waitKey` when using `cv::imshow`; -1: disable");
DEFINE_string(skeleton, "coco",
R"(Path to skeleton data or name of predefined skeletons: "coco", "coco-wholebody")");
R"(Path to skeleton data or name of predefined skeletons: "coco", "coco-wholebody", "coco-wholebody-hand")");
DEFINE_string(background, "default",
R"(Output background, "default": original image, "black": black background)");

View File

@ -73,6 +73,25 @@ const Skeleton& gSkeletonCocoWholeBody() {
return inst;
}
const Skeleton& gSkeletonCocoWholeBodyHand() {
static const Skeleton inst{
{
{0, 1}, {1, 2}, {2, 3}, {3, 4},
{0, 5}, {5, 6}, {6, 7}, {7, 8},
{0, 9}, {9, 10}, {10, 11}, {11, 12},
{0, 13}, {13, 14}, {14, 15}, {15, 16},
{0, 17}, {17, 18}, {18, 19}, {19, 20},
},
{
{255, 255, 255}, {255, 128, 0}, {255, 153, 255},
{102, 178, 255}, {255, 51, 51}, {0, 255, 0},
},
{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5,},
{0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5,},
};
return inst;
}
// n_links
// u0, v0, u1, v1, ..., un-1, vn-1
// n_palette
@ -86,6 +105,8 @@ inline Skeleton Skeleton::get(const std::string& path) {
return gSkeletonCoco();
} else if (path == "coco-wholebody") {
return gSkeletonCocoWholeBody();
} else if (path == "coco-wholebody-hand") {
return gSkeletonCocoWholeBodyHand();
}
std::ifstream ifs(path);
if (!ifs.is_open()) {

View File

@ -1,4 +1,4 @@
# onnxruntime 支持情况
# onnxruntime Support
## Introduction of ONNX Runtime

View File

@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.structures import SegDataSample
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.utils import is_dynamic_shape
from mmdeploy.utils import get_codebase_config, is_dynamic_shape
@FUNCTION_REWRITER.register_rewriter(
@ -47,4 +48,22 @@ def base_segmentor__forward(self,
for data_sample in data_samples:
data_sample.set_field(
name='img_shape', value=img_shape, field_type='metainfo')
return self.predict(inputs, data_samples)
seg_logit = self.predict(inputs, data_samples)
# mark seg_head
@mark('decode_head', outputs=['output'])
def __mark_seg_logit(seg_logit):
return seg_logit
ctx = FUNCTION_REWRITER.get_context()
with_argmax = get_codebase_config(ctx.cfg).get('with_argmax', True)
# deal with out_channels=1 with two classes
if seg_logit.shape[1] == 1:
seg_logit = seg_logit.sigmoid()
seg_pred = seg_logit > self.decode_head.threshold
seg_pred = seg_pred.to(torch.int64)
else:
seg_pred = __mark_seg_logit(seg_logit)
if with_argmax:
seg_pred = seg_pred.argmax(dim=1, keepdim=True)
return seg_pred

View File

@ -17,7 +17,7 @@ def cascade_encoder_decoder__predict(self, inputs, data_samples, **kwargs):
data_samples (SampleList): The seg data samples.
Returns:
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
torch.Tensor: Output segmentation logits of shape [N, C, H, W].
"""
batch_img_metas = []
for data_sample in data_samples:
@ -28,5 +28,4 @@ def cascade_encoder_decoder__predict(self, inputs, data_samples, **kwargs):
out = self.decode_head[i].forward(x, out)
seg_logit = self.decode_head[-1].predict(x, out, batch_img_metas,
self.test_cfg)
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
return seg_pred
return seg_logit

View File

@ -1,6 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.utils import get_codebase_config
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
@ -18,24 +17,11 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
data_samples (SampleList): The seg data samples.
Returns:
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
torch.Tensor: Output segmentation logits of shape [N, C, H, W].
"""
batch_img_metas = []
for data_sample in data_samples:
batch_img_metas.append(data_sample.metainfo)
x = self.extract_feat(inputs)
seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
ctx = FUNCTION_REWRITER.get_context()
if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
return seg_logit
# mark seg_head
@mark('decode_head', outputs=['output'])
def __mark_seg_logit(seg_logit):
return seg_logit
seg_logit = __mark_seg_logit(seg_logit)
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
return seg_pred

View File

@ -345,7 +345,11 @@ def _multiclass_nms_single(boxes: Tensor,
labels = labels[:, topk_inds, ...]
if output_index:
bbox_index = pre_topk_inds[None, topk_inds]
bbox_index = box_inds.unsqueeze(0)
if pre_top_k > 0:
bbox_index = pre_topk_inds[None, box_inds]
if keep_top_k > 0:
bbox_index = bbox_index[:, topk_inds[:-1]]
return dets, labels, bbox_index
else:
return dets, labels

View File

@ -2,6 +2,7 @@
import mmengine
import pytest
import torch
from packaging import version
from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase, Task
@ -36,14 +37,31 @@ def test_encoderdecoder_predict(backend):
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
assert torch.allclose(model_outputs, rewrite_outputs[0].squeeze(0))
rewrite_outputs = segmentor.postprocess_result(rewrite_outputs[0],
data_samples)
rewrite_outputs = rewrite_outputs[0].pred_sem_seg.data
assert torch.allclose(model_outputs, rewrite_outputs)
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
def test_basesegmentor_forward(backend):
@pytest.mark.parametrize('with_argmax,use_sigmoid', [(True, False),
(False, True)])
def test_basesegmentor_forward(backend: Backend, with_argmax: bool,
use_sigmoid: bool):
check_backend(backend)
config_path = 'tests/test_codebase/test_mmseg/data/model.py'
model_cfg = mmengine.Config.fromfile(config_path)
if use_sigmoid:
import mmseg
if version.parse(mmseg.__version__) <= version.parse('1.0.0'):
pytest.skip('ignore mmseg<=1.0.0')
model_cfg.model.decode_head.num_classes = 2
model_cfg.model.decode_head.out_channels = 1
model_cfg.model.decode_head.threshold = 0.3
deploy_cfg = generate_mmseg_deploy_config(backend.value)
task_processor = generate_mmseg_task_processor(deploy_cfg=deploy_cfg)
deploy_cfg.codebase_config.with_argmax = with_argmax
task_processor = generate_mmseg_task_processor(
deploy_cfg=deploy_cfg, model_cfg=model_cfg)
segmentor = task_processor.build_pytorch_model()
size = 256
inputs = torch.randn(1, 3, size, size)
@ -58,7 +76,11 @@ def test_basesegmentor_forward(backend):
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
assert torch.allclose(model_outputs, rewrite_outputs[0].squeeze(0))
rewrite_outputs = rewrite_outputs[0]
if rewrite_outputs.shape[1] != 1:
rewrite_outputs = rewrite_outputs.argmax(dim=1, keepdim=True)
rewrite_outputs = rewrite_outputs.squeeze(0).to(model_outputs)
assert torch.allclose(model_outputs, rewrite_outputs)
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])

View File

@ -24,7 +24,8 @@ def generate_mmseg_deploy_config(backend='onnxruntime'):
deploy_cfg = mmengine.Config(
dict(
backend_config=dict(type=backend),
codebase_config=dict(type='mmseg', task='Segmentation'),
codebase_config=dict(
type='mmseg', task='Segmentation', with_argmax=False),
onnx_config=dict(
type='onnx',
export_params=True,