From 4d9e20960dad41565ba16feb4027fa53a35d860b Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 28 Jun 2022 09:41:44 +0800 Subject: [PATCH] [Enhancement] Fix ncnn unittest (#626) * optmize-csp-darknet * replace floordiv to torch.div * update csp_darknet default implement * fix test --- mmdeploy/codebase/mmdet/models/backbones.py | 19 +++++++++++++++++++ .../test_mmdet/test_mmdet_models.py | 11 +++++------ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/mmdeploy/codebase/mmdet/models/backbones.py b/mmdeploy/codebase/mmdet/models/backbones.py index 65b672785..76afc176b 100644 --- a/mmdeploy/codebase/mmdet/models/backbones.py +++ b/mmdeploy/codebase/mmdet/models/backbones.py @@ -4,6 +4,25 @@ import torch from mmdeploy.core import FUNCTION_REWRITER +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.models.backbones.csp_darknet.Focus.forward') +def focus__forward__default(ctx, self, x): + """Rewrite forward function of Focus class. + + Replace slice with transpose. + """ + # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) + B, C, H, W = x.shape + x = x.reshape(B, C, -1, 2, W) + x = x.reshape(B, C, x.shape[2], 2, -1, 2) + half_H = x.shape[2] + half_W = x.shape[4] + x = x.permute(0, 5, 3, 1, 2, 4) + x = x.reshape(B, C * 4, half_H, half_W) + + return self.conv(x) + + @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.backbones.csp_darknet.Focus.forward', backend='ncnn') diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index def6a6283..9a8c075c8 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -201,8 +201,8 @@ def get_gfl_head_model(): return model -def test_focus_forward_ncnn(): - backend_type = Backend.NCNN +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN]) +def test_focus_forward(backend_type): check_backend(backend_type) focus_model = get_focus_backbone_model() focus_model.cpu().eval() @@ -222,11 +222,10 @@ def test_focus_forward_ncnn(): wrapped_model=wrapped_model, model_inputs=rewrite_inputs, deploy_cfg=deploy_cfg) - for model_output, rewrite_output in zip(model_outputs[0], - rewrite_outputs[0]): - model_output = model_output.squeeze().cpu().numpy() + for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs): + model_output = model_output.squeeze() rewrite_output = rewrite_output.squeeze() - assert np.allclose( + torch.testing.assert_allclose( model_output, rewrite_output, rtol=1e-03, atol=1e-05)