[Enhancement] Fix ncnn unittest (#626)

* optmize-csp-darknet

* replace floordiv to torch.div

* update csp_darknet default implement

* fix test
pull/646/head^2
q.yao 2022-06-28 09:41:44 +08:00 committed by GitHub
parent 05cafab723
commit 4d9e20960d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 6 deletions

View File

@ -4,6 +4,25 @@ import torch
from mmdeploy.core import FUNCTION_REWRITER
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.csp_darknet.Focus.forward')
def focus__forward__default(ctx, self, x):
"""Rewrite forward function of Focus class.
Replace slice with transpose.
"""
# shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
B, C, H, W = x.shape
x = x.reshape(B, C, -1, 2, W)
x = x.reshape(B, C, x.shape[2], 2, -1, 2)
half_H = x.shape[2]
half_W = x.shape[4]
x = x.permute(0, 5, 3, 1, 2, 4)
x = x.reshape(B, C * 4, half_H, half_W)
return self.conv(x)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.csp_darknet.Focus.forward',
backend='ncnn')

View File

@ -201,8 +201,8 @@ def get_gfl_head_model():
return model
def test_focus_forward_ncnn():
backend_type = Backend.NCNN
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME, Backend.NCNN])
def test_focus_forward(backend_type):
check_backend(backend_type)
focus_model = get_focus_backbone_model()
focus_model.cpu().eval()
@ -222,11 +222,10 @@ def test_focus_forward_ncnn():
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs[0]):
model_output = model_output.squeeze().cpu().numpy()
for model_output, rewrite_output in zip(model_outputs[0], rewrite_outputs):
model_output = model_output.squeeze()
rewrite_output = rewrite_output.squeeze()
assert np.allclose(
torch.testing.assert_allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)