diff --git a/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py b/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py index 48b41ebb3..5f55d3d05 100644 --- a/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py +++ b/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py @@ -9,7 +9,6 @@ from mmdeploy.codebase import import_codebase from mmdeploy.utils import Backend, Codebase, Task, load_config from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs -import_codebase(Codebase.MMDET3D) try: import_codebase(Codebase.MMDET3D) except ImportError: @@ -61,7 +60,6 @@ def test_pillar_encoder(backend_type: Backend): num_points = torch.randint(0, 32, (3945, ), dtype=torch.int32) coors = torch.randint(0, 10, (3945, 4), dtype=torch.int32) model_outputs = model.forward(features, num_points, coors) - model_outputs = [model_outputs] wrapped_model = WrapModel(model, 'forward') rewrite_inputs = { 'features': features, @@ -74,11 +72,11 @@ def test_pillar_encoder(backend_type: Backend): deploy_cfg=deploy_cfg) if isinstance(rewrite_outputs, dict): rewrite_outputs = rewrite_outputs['output'] - for model_output, rewrite_output in zip(model_outputs, rewrite_outputs): - if isinstance(rewrite_output, torch.Tensor): - rewrite_output = rewrite_output.cpu().numpy() - assert np.allclose( - model_output.shape, rewrite_output.shape, rtol=1e-03, atol=1e-03) + if isinstance(rewrite_outputs, list): + rewrite_outputs = rewrite_outputs[0] + + assert np.allclose( + model_outputs.shape, rewrite_outputs.shape, rtol=1e-03, atol=1e-03) @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) @@ -99,20 +97,16 @@ def test_pointpillars_scatter(backend_type: Backend): voxel_features = torch.rand(16 * 16, 64) * 100 coors = torch.randint(0, 10, (16 * 16, 4), dtype=torch.int32) model_outputs = model.forward_batch(voxel_features, coors, 1) - model_outputs = [model_outputs] wrapped_model = WrapModel(model, 'forward_batch') rewrite_inputs = {'voxel_features': voxel_features, 'coors': coors} rewrite_outputs, is_backend_output = get_rewrite_outputs( wrapped_model=wrapped_model, model_inputs=rewrite_inputs, deploy_cfg=deploy_cfg) - if isinstance(rewrite_outputs, dict): - rewrite_outputs = rewrite_outputs['output'] - for model_output, rewrite_output in zip(model_outputs, rewrite_outputs): - if isinstance(rewrite_output, torch.Tensor): - rewrite_output = rewrite_output.cpu().numpy() - assert np.allclose( - model_output.shape, rewrite_output.shape, rtol=1e-03, atol=1e-03) + if isinstance(rewrite_outputs, list): + rewrite_outputs = rewrite_outputs[0] + assert np.allclose( + model_outputs.shape, rewrite_outputs.shape, rtol=1e-03, atol=1e-03) def get_centerpoint():