[Unittest] Add L2Norm unittest for default backend. (#43)
* add l2norm ut * fix yapf * fix test bugs * fix yapf * fix is_backend_outputpull/20/head^2
parent
81770e26c1
commit
9fd15e3843
|
@ -112,6 +112,15 @@ def get_fcos_head_model():
|
|||
return model
|
||||
|
||||
|
||||
def get_l2norm_forward_model():
|
||||
"""L2Norm Neck Config."""
|
||||
from mmdet.models.necks.ssd_neck import L2Norm
|
||||
model = L2Norm(16)
|
||||
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
def get_rpn_head_model():
|
||||
"""RPN Head Config."""
|
||||
test_cfg = mmcv.Config(
|
||||
|
@ -139,6 +148,43 @@ def get_single_roi_extractor():
|
|||
return model
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
def test_l2norm_forward(backend_type):
|
||||
check_backend(backend_type)
|
||||
l2norm_neck = get_l2norm_forward_model()
|
||||
l2norm_neck.cpu().eval()
|
||||
s = 128
|
||||
deploy_cfg = mmcv.Config(
|
||||
dict(
|
||||
backend_config=dict(type=backend_type.value),
|
||||
onnx_config=dict(input_shape=None)))
|
||||
feat = torch.rand(1, 16, s, s)
|
||||
model_outputs = [l2norm_neck.forward(feat)]
|
||||
wrapped_model = WrapModel(l2norm_neck, 'forward')
|
||||
rewrite_inputs = {
|
||||
'x': feat,
|
||||
}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
|
||||
if is_backend_output:
|
||||
for model_output, rewrite_output in zip(model_outputs[0],
|
||||
rewrite_outputs[0]):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
else:
|
||||
for model_output, rewrite_output in zip(model_outputs[0],
|
||||
rewrite_outputs[0]):
|
||||
model_output = model_output.squeeze().cpu().numpy()
|
||||
rewrite_output = rewrite_output.squeeze()
|
||||
assert np.allclose(
|
||||
model_output[0], rewrite_output, rtol=1e-03, atol=1e-05)
|
||||
|
||||
|
||||
def test_get_bboxes_of_fcos_head_ncnn():
|
||||
backend_type = Backend.NCNN
|
||||
check_backend(backend_type)
|
||||
|
|
Loading…
Reference in New Issue