fix(circleci): rewrite outputs is a list (#1319)

This commit is contained in:
tpoisonooo 2022-11-08 17:26:36 +08:00 committed by GitHub
parent 42fac7e004
commit 0dc25a27dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -9,7 +9,6 @@ from mmdeploy.codebase import import_codebase
from mmdeploy.utils import Backend, Codebase, Task, load_config
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
import_codebase(Codebase.MMDET3D)
try:
import_codebase(Codebase.MMDET3D)
except ImportError:
@ -61,7 +60,6 @@ def test_pillar_encoder(backend_type: Backend):
num_points = torch.randint(0, 32, (3945, ), dtype=torch.int32)
coors = torch.randint(0, 10, (3945, 4), dtype=torch.int32)
model_outputs = model.forward(features, num_points, coors)
model_outputs = [model_outputs]
wrapped_model = WrapModel(model, 'forward')
rewrite_inputs = {
'features': features,
@ -74,11 +72,11 @@ def test_pillar_encoder(backend_type: Backend):
deploy_cfg=deploy_cfg)
if isinstance(rewrite_outputs, dict):
rewrite_outputs = rewrite_outputs['output']
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
if isinstance(rewrite_output, torch.Tensor):
rewrite_output = rewrite_output.cpu().numpy()
assert np.allclose(
model_output.shape, rewrite_output.shape, rtol=1e-03, atol=1e-03)
if isinstance(rewrite_outputs, list):
rewrite_outputs = rewrite_outputs[0]
assert np.allclose(
model_outputs.shape, rewrite_outputs.shape, rtol=1e-03, atol=1e-03)
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
@ -99,20 +97,16 @@ def test_pointpillars_scatter(backend_type: Backend):
voxel_features = torch.rand(16 * 16, 64) * 100
coors = torch.randint(0, 10, (16 * 16, 4), dtype=torch.int32)
model_outputs = model.forward_batch(voxel_features, coors, 1)
model_outputs = [model_outputs]
wrapped_model = WrapModel(model, 'forward_batch')
rewrite_inputs = {'voxel_features': voxel_features, 'coors': coors}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if isinstance(rewrite_outputs, dict):
rewrite_outputs = rewrite_outputs['output']
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
if isinstance(rewrite_output, torch.Tensor):
rewrite_output = rewrite_output.cpu().numpy()
assert np.allclose(
model_output.shape, rewrite_output.shape, rtol=1e-03, atol=1e-03)
if isinstance(rewrite_outputs, list):
rewrite_outputs = rewrite_outputs[0]
assert np.allclose(
model_outputs.shape, rewrite_outputs.shape, rtol=1e-03, atol=1e-03)
def get_centerpoint():