mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* support cascade (mask) rcnn
* fix docstring
* support SwinTransformer
* move dense_head support to this branch
* fix function names
* fix part of uts of mmdet
* fix for mmdet ut
* fix det model cfg for ut
* fix test_object_detection.py
* fix mmdet object_detection_model.py
* fix mmdet yolov3 ort ut
* fix part of uts
* fix cascade bbox head ut
* fix cascade bbox head ut
* remove useless ssd ncnn test
* fix ncnn wrapper
* fix openvino ut for reppoint head
* fix openvino cascade mask rcnn
* sync codes
* support roll
* remove unused pad
* fix yolox
* fix isort
* fix lint
* fix flake8
* address review comments and fix failed ut
* fix sdk_export in dump_info
* fix temp hidden xlsx bugs
* fix mmdet regression test
* fix lint
* fix timer
* fix timecount side-effect
* adapt profile.py for mmdet 2.0
* hardcode report.txt for T4 benchmark test: temp version
* fix no-visualizer case
* fix backend_model
* fix android build
* adapt new mmdet 2.0 0825
* fix new 2.0
* fix test_mmdet_structures
* fix test_object_detection
* fix codebase import
* fix ut
* fix all mmdet uts
* fix det
* fix mmdet trt
* fix ncnn onnx optimize
51 lines
1.7 KiB
Python
51 lines
1.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from mmengine import Config
|
|
|
|
from mmdeploy.codebase import import_codebase
|
|
from mmdeploy.utils import Backend, Codebase
|
|
from mmdeploy.utils.test import (WrapFunction, check_backend,
|
|
get_rewrite_outputs)
|
|
|
|
# Load mmdeploy's MMDet codebase support at module import time, before any
# test runs. NOTE(review): presumably this registers the mmdet function
# rewriters (e.g. the one for distance2bbox exercised below) — confirm.
import_codebase(Codebase.MMDET)
|
|
|
|
|
|
@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_distance2bbox(backend_type: Backend):
    """Check that rewritten ``distance2bbox`` matches the eager original.

    The mmdet ``distance2bbox`` function is run once directly in PyTorch to
    get a reference, then exported and evaluated through
    ``get_rewrite_outputs``. When the result comes from the backend, it must
    match the reference within ``rtol=1e-03`` / ``atol=1e-05``. Otherwise,
    the test only checks that some output was produced.
    """
    check_backend(backend_type)

    deploy_cfg = Config(
        dict(
            onnx_config=dict(output_names=None, input_shape=None),
            backend_config=dict(type=backend_type.value, model_inputs=None),
            codebase_config=dict(type='mmdet', task='ObjectDetection')))

    def distance2bbox(*args, **kwargs):
        # Look the mmdet function up on every call rather than binding it
        # once, so that a rewrite active during export is the version that
        # gets invoked.
        from mmdet.structures.bbox import transforms as bbox_transforms
        return bbox_transforms.distance2bbox(*args, **kwargs)

    points = torch.rand(3, 2)
    distance = torch.rand(3, 4)
    reference = distance2bbox(points, distance)

    # WrapFunction turns the plain function into an nn.Module so that it can
    # go through torch.onnx.export.
    model = WrapFunction(distance2bbox)
    inputs = {'points': points, 'distance': distance}
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        model, model_inputs=inputs, deploy_cfg=deploy_cfg)

    if not is_backend_output:
        # Nothing backend-produced to compare numerically; just require
        # that the rewrite step returned something.
        assert rewrite_outputs is not None
        return

    expected = reference.squeeze().cpu().numpy()
    actual = rewrite_outputs[0].squeeze()
    assert np.allclose(expected, actual, rtol=1e-03, atol=1e-05)
|