mmdeploy/tests/test_mmcv/test_mmcv_cnn.py
q.yao 3f261e6d50
[Refactor] Refactor rewriter context for MMRazor (#1483)
* wip

* update rewriter

* Support all codebase

* update docs

* fix ssd

* rename qualname

* support torch.fx.wrap

* import by torch version

Co-authored-by: pppppM <gjf_mail@126.com>
2022-12-13 19:03:56 +08:00

33 lines
1.1 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine import Config
from mmdeploy.utils import Backend
from mmdeploy.utils.test import check_backend, get_rewrite_outputs
def test_multiheadattention_ncnn():
    """Check mmcv's MultiheadAttention against its ncnn-rewritten version.

    A reference output is produced by running the module in PyTorch. The
    module is then passed through ``get_rewrite_outputs`` with an ncnn deploy
    config (``run_with_backend=True``). When the result is reported as a
    backend output, it must match the reference within the given tolerances.
    """
    # Guards the test on ncnn availability (presumably skipping it when the
    # backend is not installed -- see mmdeploy.utils.test.check_backend).
    check_backend(Backend.NCNN)
    # Imported inside the test, presumably so that collecting this module
    # does not require mmcv when the backend check bails out.
    from mmcv.cnn.bricks.transformer import MultiheadAttention

    embed_dims, num_heads = 12, 2
    model = MultiheadAttention(embed_dims, num_heads, batch_first=True)
    # The reference must come from inference mode. Otherwise any dropout in
    # the module could make it disagree with the exported graph.
    model.eval()
    # With batch_first=True this is presumably (batch, num_queries,
    # embed_dims) -- TODO confirm against mmcv's MultiheadAttention docs.
    query = torch.rand(1, 3, embed_dims)

    deploy_cfg = Config(
        dict(
            onnx_config=dict(input_shape=None),
            backend_config=dict(type=Backend.NCNN.value),
        ))

    # Reference output; no autograd graph is needed for a value comparison.
    with torch.no_grad():
        model_outputs = model(query)

    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=model,
        model_inputs=dict(query=query),
        deploy_cfg=deploy_cfg,
        run_with_backend=True)

    # NOTE(review): this branch is taken only when get_rewrite_outputs
    # reports ``None``. If it returns a plain bool instead, the numeric
    # comparison below always runs -- confirm the helper's return contract.
    if is_backend_output is None:
        assert rewrite_outputs is not None
    else:
        assert torch.allclose(
            model_outputs, rewrite_outputs[0], rtol=1e-03, atol=1e-05)