mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
Merge remote-tracking branch 'upstream/main' into fix_ut
This commit is contained in:
commit
9e7a187f5c
@ -18,7 +18,7 @@ DEFINE_int32(flip, 0, "Set to 1 for flipping the input horizontally");
|
||||
DEFINE_int32(show, 1, "Delay passed to `cv::waitKey` when using `cv::imshow`; -1: disable");
|
||||
|
||||
DEFINE_string(skeleton, "coco",
|
||||
R"(Path to skeleton data or name of predefined skeletons: "coco", "coco-wholebody")");
|
||||
R"(Path to skeleton data or name of predefined skeletons: "coco", "coco-wholebody", "coco-wholebody-hand")");
|
||||
DEFINE_string(background, "default",
|
||||
R"(Output background, "default": original image, "black": black background)");
|
||||
|
||||
|
@ -73,6 +73,25 @@ const Skeleton& gSkeletonCocoWholeBody() {
|
||||
return inst;
|
||||
}
|
||||
|
||||
const Skeleton& gSkeletonCocoWholeBodyHand() {
|
||||
static const Skeleton inst{
|
||||
{
|
||||
{0, 1}, {1, 2}, {2, 3}, {3, 4},
|
||||
{0, 5}, {5, 6}, {6, 7}, {7, 8},
|
||||
{0, 9}, {9, 10}, {10, 11}, {11, 12},
|
||||
{0, 13}, {13, 14}, {14, 15}, {15, 16},
|
||||
{0, 17}, {17, 18}, {18, 19}, {19, 20},
|
||||
},
|
||||
{
|
||||
{255, 255, 255}, {255, 128, 0}, {255, 153, 255},
|
||||
{102, 178, 255}, {255, 51, 51}, {0, 255, 0},
|
||||
},
|
||||
{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5,},
|
||||
{0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5,},
|
||||
};
|
||||
return inst;
|
||||
}
|
||||
|
||||
// n_links
|
||||
// u0, v0, u1, v1, ..., un-1, vn-1
|
||||
// n_palette
|
||||
@ -86,6 +105,8 @@ inline Skeleton Skeleton::get(const std::string& path) {
|
||||
return gSkeletonCoco();
|
||||
} else if (path == "coco-wholebody") {
|
||||
return gSkeletonCocoWholeBody();
|
||||
} else if (path == "coco-wholebody-hand") {
|
||||
return gSkeletonCocoWholeBodyHand();
|
||||
}
|
||||
std::ifstream ifs(path);
|
||||
if (!ifs.is_open()) {
|
||||
|
@ -1,4 +1,4 @@
|
||||
# onnxruntime 支持情况
|
||||
# onnxruntime Support
|
||||
|
||||
## Introduction of ONNX Runtime
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmseg.structures import SegDataSample
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER, mark
|
||||
from mmdeploy.utils import is_dynamic_shape
|
||||
from mmdeploy.utils import get_codebase_config, is_dynamic_shape
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
@ -47,4 +48,22 @@ def base_segmentor__forward(self,
|
||||
for data_sample in data_samples:
|
||||
data_sample.set_field(
|
||||
name='img_shape', value=img_shape, field_type='metainfo')
|
||||
return self.predict(inputs, data_samples)
|
||||
seg_logit = self.predict(inputs, data_samples)
|
||||
|
||||
# mark seg_head
|
||||
@mark('decode_head', outputs=['output'])
|
||||
def __mark_seg_logit(seg_logit):
|
||||
return seg_logit
|
||||
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
with_argmax = get_codebase_config(ctx.cfg).get('with_argmax', True)
|
||||
# deal with out_channels=1 with two classes
|
||||
if seg_logit.shape[1] == 1:
|
||||
seg_logit = seg_logit.sigmoid()
|
||||
seg_pred = seg_logit > self.decode_head.threshold
|
||||
seg_pred = seg_pred.to(torch.int64)
|
||||
else:
|
||||
seg_pred = __mark_seg_logit(seg_logit)
|
||||
if with_argmax:
|
||||
seg_pred = seg_pred.argmax(dim=1, keepdim=True)
|
||||
return seg_pred
|
||||
|
@ -17,7 +17,7 @@ def cascade_encoder_decoder__predict(self, inputs, data_samples, **kwargs):
|
||||
data_samples (SampleList): The seg data samples.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
|
||||
torch.Tensor: Output segmentation logits of shape [N, C, H, W].
|
||||
"""
|
||||
batch_img_metas = []
|
||||
for data_sample in data_samples:
|
||||
@ -28,5 +28,4 @@ def cascade_encoder_decoder__predict(self, inputs, data_samples, **kwargs):
|
||||
out = self.decode_head[i].forward(x, out)
|
||||
seg_logit = self.decode_head[-1].predict(x, out, batch_img_metas,
|
||||
self.test_cfg)
|
||||
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
|
||||
return seg_pred
|
||||
return seg_logit
|
||||
|
@ -1,6 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmdeploy.core import FUNCTION_REWRITER, mark
|
||||
from mmdeploy.utils import get_codebase_config
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
@ -18,24 +17,11 @@ def encoder_decoder__predict(self, inputs, data_samples, **kwargs):
|
||||
data_samples (SampleList): The seg data samples.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output segmentation map pf shape [N, 1, H, W].
|
||||
torch.Tensor: Output segmentation logits of shape [N, C, H, W].
|
||||
"""
|
||||
batch_img_metas = []
|
||||
for data_sample in data_samples:
|
||||
batch_img_metas.append(data_sample.metainfo)
|
||||
x = self.extract_feat(inputs)
|
||||
seg_logit = self.decode_head.predict(x, batch_img_metas, self.test_cfg)
|
||||
|
||||
ctx = FUNCTION_REWRITER.get_context()
|
||||
if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
|
||||
return seg_logit
|
||||
|
||||
# mark seg_head
|
||||
@mark('decode_head', outputs=['output'])
|
||||
def __mark_seg_logit(seg_logit):
|
||||
return seg_logit
|
||||
|
||||
seg_logit = __mark_seg_logit(seg_logit)
|
||||
|
||||
seg_pred = seg_logit.argmax(dim=1, keepdim=True)
|
||||
return seg_pred
|
||||
|
@ -345,7 +345,11 @@ def _multiclass_nms_single(boxes: Tensor,
|
||||
labels = labels[:, topk_inds, ...]
|
||||
|
||||
if output_index:
|
||||
bbox_index = pre_topk_inds[None, topk_inds]
|
||||
bbox_index = box_inds.unsqueeze(0)
|
||||
if pre_top_k > 0:
|
||||
bbox_index = pre_topk_inds[None, box_inds]
|
||||
if keep_top_k > 0:
|
||||
bbox_index = bbox_index[:, topk_inds[:-1]]
|
||||
return dets, labels, bbox_index
|
||||
else:
|
||||
return dets, labels
|
||||
|
@ -2,6 +2,7 @@
|
||||
import mmengine
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from mmdeploy.codebase import import_codebase
|
||||
from mmdeploy.utils import Backend, Codebase, Task
|
||||
@ -36,14 +37,31 @@ def test_encoderdecoder_predict(backend):
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
assert torch.allclose(model_outputs, rewrite_outputs[0].squeeze(0))
|
||||
rewrite_outputs = segmentor.postprocess_result(rewrite_outputs[0],
|
||||
data_samples)
|
||||
rewrite_outputs = rewrite_outputs[0].pred_sem_seg.data
|
||||
assert torch.allclose(model_outputs, rewrite_outputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
|
||||
def test_basesegmentor_forward(backend):
|
||||
@pytest.mark.parametrize('with_argmax,use_sigmoid', [(True, False),
|
||||
(False, True)])
|
||||
def test_basesegmentor_forward(backend: Backend, with_argmax: bool,
|
||||
use_sigmoid: bool):
|
||||
check_backend(backend)
|
||||
config_path = 'tests/test_codebase/test_mmseg/data/model.py'
|
||||
model_cfg = mmengine.Config.fromfile(config_path)
|
||||
if use_sigmoid:
|
||||
import mmseg
|
||||
if version.parse(mmseg.__version__) <= version.parse('1.0.0'):
|
||||
pytest.skip('ignore mmseg<=1.0.0')
|
||||
model_cfg.model.decode_head.num_classes = 2
|
||||
model_cfg.model.decode_head.out_channels = 1
|
||||
model_cfg.model.decode_head.threshold = 0.3
|
||||
deploy_cfg = generate_mmseg_deploy_config(backend.value)
|
||||
task_processor = generate_mmseg_task_processor(deploy_cfg=deploy_cfg)
|
||||
deploy_cfg.codebase_config.with_argmax = with_argmax
|
||||
task_processor = generate_mmseg_task_processor(
|
||||
deploy_cfg=deploy_cfg, model_cfg=model_cfg)
|
||||
segmentor = task_processor.build_pytorch_model()
|
||||
size = 256
|
||||
inputs = torch.randn(1, 3, size, size)
|
||||
@ -58,7 +76,11 @@ def test_basesegmentor_forward(backend):
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
assert torch.allclose(model_outputs, rewrite_outputs[0].squeeze(0))
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
if rewrite_outputs.shape[1] != 1:
|
||||
rewrite_outputs = rewrite_outputs.argmax(dim=1, keepdim=True)
|
||||
rewrite_outputs = rewrite_outputs.squeeze(0).to(model_outputs)
|
||||
assert torch.allclose(model_outputs, rewrite_outputs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME])
|
||||
|
@ -24,7 +24,8 @@ def generate_mmseg_deploy_config(backend='onnxruntime'):
|
||||
deploy_cfg = mmengine.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend),
|
||||
codebase_config=dict(type='mmseg', task='Segmentation'),
|
||||
codebase_config=dict(
|
||||
type='mmseg', task='Segmentation', with_argmax=False),
|
||||
onnx_config=dict(
|
||||
type='onnx',
|
||||
export_params=True,
|
||||
|
Loading…
x
Reference in New Issue
Block a user