From 4f49763c28e8acab6802685cc5e562ec8e2cc42d Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Wed, 25 May 2022 09:52:42 +0800 Subject: [PATCH] fix mmseg twice resize (#480) * fix mmseg twich resize * remove comment --- .../mmseg/models/segmentors/encoder_decoder.py | 14 +------------- .../test_codebase/test_mmseg/test_mmseg_models.py | 6 ++++-- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py index bca614ae8..0ed9ace84 100644 --- a/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py +++ b/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py @@ -1,9 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch.nn.functional as F -from mmseg.ops import resize from mmdeploy.core import FUNCTION_REWRITER -from mmdeploy.utils import is_dynamic_shape @FUNCTION_REWRITER.register_rewriter( @@ -25,16 +23,6 @@ def encoder_decoder__simple_test(ctx, self, img, img_meta, **kwargs): torch.Tensor: Output segmentation map pf shape [N, 1, H, W]. """ seg_logit = self.encode_decode(img, img_meta) - seg_logit = resize( - input=seg_logit, - size=img_meta['img_shape'], - mode='bilinear', - align_corners=self.align_corners) seg_logit = F.softmax(seg_logit, dim=1) - seg_pred = seg_logit.argmax(dim=1) - # our inference backend only support 4D output - shape = seg_pred.shape - if not is_dynamic_shape(ctx.cfg): - shape = [int(_) for _ in shape] - seg_pred = seg_pred.view(shape[0], 1, shape[1], shape[2]) + seg_pred = seg_logit.argmax(dim=1, keepdim=True) return seg_pred diff --git a/tests/test_codebase/test_mmseg/test_mmseg_models.py b/tests/test_codebase/test_mmseg/test_mmseg_models.py index dfcd5b4cd..d5f228593 100644 --- a/tests/test_codebase/test_mmseg/test_mmseg_models.py +++ b/tests/test_codebase/test_mmseg/test_mmseg_models.py @@ -93,7 +93,8 @@ def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10): return mm_inputs -@pytest.mark.parametrize('backend', [Backend.ONNXRUNTIME, Backend.OPENVINO]) +@pytest.mark.parametrize('backend', + [Backend.ONNXRUNTIME, Backend.OPENVINO, Backend.NCNN]) def test_encoderdecoder_simple_test(backend): check_backend(backend) segmentor = get_model() @@ -109,7 +110,8 @@ def test_encoderdecoder_simple_test(backend): num_classes = segmentor.decode_head[-1].num_classes else: num_classes = segmentor.decode_head.num_classes - mm_inputs = _demo_mm_inputs(num_classes=num_classes) + mm_inputs = _demo_mm_inputs( + input_shape=(1, 3, 32, 32), num_classes=num_classes) imgs = mm_inputs.pop('imgs') img_metas = mm_inputs.pop('img_metas') model_inputs = {'img': imgs, 'img_meta': img_metas}