mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
fix(circleci): rewrite outputs is a list (#1319)
This commit is contained in:
parent
42fac7e004
commit
0dc25a27dd
@ -9,7 +9,6 @@ from mmdeploy.codebase import import_codebase
|
||||
from mmdeploy.utils import Backend, Codebase, Task, load_config
|
||||
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
|
||||
|
||||
import_codebase(Codebase.MMDET3D)
|
||||
try:
|
||||
import_codebase(Codebase.MMDET3D)
|
||||
except ImportError:
|
||||
@ -61,7 +60,6 @@ def test_pillar_encoder(backend_type: Backend):
|
||||
num_points = torch.randint(0, 32, (3945, ), dtype=torch.int32)
|
||||
coors = torch.randint(0, 10, (3945, 4), dtype=torch.int32)
|
||||
model_outputs = model.forward(features, num_points, coors)
|
||||
model_outputs = [model_outputs]
|
||||
wrapped_model = WrapModel(model, 'forward')
|
||||
rewrite_inputs = {
|
||||
'features': features,
|
||||
@ -74,11 +72,11 @@ def test_pillar_encoder(backend_type: Backend):
|
||||
deploy_cfg=deploy_cfg)
|
||||
if isinstance(rewrite_outputs, dict):
|
||||
rewrite_outputs = rewrite_outputs['output']
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
if isinstance(rewrite_output, torch.Tensor):
|
||||
rewrite_output = rewrite_output.cpu().numpy()
|
||||
assert np.allclose(
|
||||
model_output.shape, rewrite_output.shape, rtol=1e-03, atol=1e-03)
|
||||
if isinstance(rewrite_outputs, list):
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
|
||||
assert np.allclose(
|
||||
model_outputs.shape, rewrite_outputs.shape, rtol=1e-03, atol=1e-03)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||
@ -99,20 +97,16 @@ def test_pointpillars_scatter(backend_type: Backend):
|
||||
voxel_features = torch.rand(16 * 16, 64) * 100
|
||||
coors = torch.randint(0, 10, (16 * 16, 4), dtype=torch.int32)
|
||||
model_outputs = model.forward_batch(voxel_features, coors, 1)
|
||||
model_outputs = [model_outputs]
|
||||
wrapped_model = WrapModel(model, 'forward_batch')
|
||||
rewrite_inputs = {'voxel_features': voxel_features, 'coors': coors}
|
||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||
wrapped_model=wrapped_model,
|
||||
model_inputs=rewrite_inputs,
|
||||
deploy_cfg=deploy_cfg)
|
||||
if isinstance(rewrite_outputs, dict):
|
||||
rewrite_outputs = rewrite_outputs['output']
|
||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
||||
if isinstance(rewrite_output, torch.Tensor):
|
||||
rewrite_output = rewrite_output.cpu().numpy()
|
||||
assert np.allclose(
|
||||
model_output.shape, rewrite_output.shape, rtol=1e-03, atol=1e-03)
|
||||
if isinstance(rewrite_outputs, list):
|
||||
rewrite_outputs = rewrite_outputs[0]
|
||||
assert np.allclose(
|
||||
model_outputs.shape, rewrite_outputs.shape, rtol=1e-03, atol=1e-03)
|
||||
|
||||
|
||||
def get_centerpoint():
|
||||
|
Loading…
x
Reference in New Issue
Block a user