mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
fix(circleci): rewrite outputs is a list (#1319)
This commit is contained in:
parent
42fac7e004
commit
0dc25a27dd
@ -9,7 +9,6 @@ from mmdeploy.codebase import import_codebase
|
|||||||
from mmdeploy.utils import Backend, Codebase, Task, load_config
|
from mmdeploy.utils import Backend, Codebase, Task, load_config
|
||||||
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
|
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
|
||||||
|
|
||||||
import_codebase(Codebase.MMDET3D)
|
|
||||||
try:
|
try:
|
||||||
import_codebase(Codebase.MMDET3D)
|
import_codebase(Codebase.MMDET3D)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -61,7 +60,6 @@ def test_pillar_encoder(backend_type: Backend):
|
|||||||
num_points = torch.randint(0, 32, (3945, ), dtype=torch.int32)
|
num_points = torch.randint(0, 32, (3945, ), dtype=torch.int32)
|
||||||
coors = torch.randint(0, 10, (3945, 4), dtype=torch.int32)
|
coors = torch.randint(0, 10, (3945, 4), dtype=torch.int32)
|
||||||
model_outputs = model.forward(features, num_points, coors)
|
model_outputs = model.forward(features, num_points, coors)
|
||||||
model_outputs = [model_outputs]
|
|
||||||
wrapped_model = WrapModel(model, 'forward')
|
wrapped_model = WrapModel(model, 'forward')
|
||||||
rewrite_inputs = {
|
rewrite_inputs = {
|
||||||
'features': features,
|
'features': features,
|
||||||
@ -74,11 +72,11 @@ def test_pillar_encoder(backend_type: Backend):
|
|||||||
deploy_cfg=deploy_cfg)
|
deploy_cfg=deploy_cfg)
|
||||||
if isinstance(rewrite_outputs, dict):
|
if isinstance(rewrite_outputs, dict):
|
||||||
rewrite_outputs = rewrite_outputs['output']
|
rewrite_outputs = rewrite_outputs['output']
|
||||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
if isinstance(rewrite_outputs, list):
|
||||||
if isinstance(rewrite_output, torch.Tensor):
|
rewrite_outputs = rewrite_outputs[0]
|
||||||
rewrite_output = rewrite_output.cpu().numpy()
|
|
||||||
assert np.allclose(
|
assert np.allclose(
|
||||||
model_output.shape, rewrite_output.shape, rtol=1e-03, atol=1e-03)
|
model_outputs.shape, rewrite_outputs.shape, rtol=1e-03, atol=1e-03)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
|
||||||
@ -99,20 +97,16 @@ def test_pointpillars_scatter(backend_type: Backend):
|
|||||||
voxel_features = torch.rand(16 * 16, 64) * 100
|
voxel_features = torch.rand(16 * 16, 64) * 100
|
||||||
coors = torch.randint(0, 10, (16 * 16, 4), dtype=torch.int32)
|
coors = torch.randint(0, 10, (16 * 16, 4), dtype=torch.int32)
|
||||||
model_outputs = model.forward_batch(voxel_features, coors, 1)
|
model_outputs = model.forward_batch(voxel_features, coors, 1)
|
||||||
model_outputs = [model_outputs]
|
|
||||||
wrapped_model = WrapModel(model, 'forward_batch')
|
wrapped_model = WrapModel(model, 'forward_batch')
|
||||||
rewrite_inputs = {'voxel_features': voxel_features, 'coors': coors}
|
rewrite_inputs = {'voxel_features': voxel_features, 'coors': coors}
|
||||||
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
rewrite_outputs, is_backend_output = get_rewrite_outputs(
|
||||||
wrapped_model=wrapped_model,
|
wrapped_model=wrapped_model,
|
||||||
model_inputs=rewrite_inputs,
|
model_inputs=rewrite_inputs,
|
||||||
deploy_cfg=deploy_cfg)
|
deploy_cfg=deploy_cfg)
|
||||||
if isinstance(rewrite_outputs, dict):
|
if isinstance(rewrite_outputs, list):
|
||||||
rewrite_outputs = rewrite_outputs['output']
|
rewrite_outputs = rewrite_outputs[0]
|
||||||
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
|
assert np.allclose(
|
||||||
if isinstance(rewrite_output, torch.Tensor):
|
model_outputs.shape, rewrite_outputs.shape, rtol=1e-03, atol=1e-03)
|
||||||
rewrite_output = rewrite_output.cpu().numpy()
|
|
||||||
assert np.allclose(
|
|
||||||
model_output.shape, rewrite_output.shape, rtol=1e-03, atol=1e-03)
|
|
||||||
|
|
||||||
|
|
||||||
def get_centerpoint():
|
def get_centerpoint():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user