# Source: mmdeploy/tests/test_ops/test_ops.py (Python, 44 lines, 1.4 KiB)

import pytest
import torch
from mmdeploy.utils.test import WrapFunction
from .utils import TestOnnxRTExporter, TestTensorRTExporter
# One exporter helper per backend under test. The test in this module calls
# ``check_env()`` and ``run_and_validate()`` on each of them.
TEST_TENSORRT = TestTensorRTExporter()
TEST_ONNXRT = TestOnnxRTExporter()

# Backends that the parametrized op test is run against.
ALL_BACKEND = [
    TEST_TENSORRT,
    TEST_ONNXRT,
]
@pytest.mark.parametrize('backend', ALL_BACKEND)
@pytest.mark.parametrize('pool_h,pool_w,spatial_scale,sampling_ratio',
                         [(2, 2, 1.0, 2), (4, 4, 2.0, 4)])
def test_roi_align(backend,
                   pool_h,
                   pool_w,
                   spatial_scale,
                   sampling_ratio,
                   inputs=None,
                   work_dir=None):
    """Run mmcv's ``roi_align`` through ``backend`` and validate the result.

    The op is wrapped in a ``WrapFunction`` module and handed to
    ``backend.run_and_validate`` together with the input tensors.

    Args:
        backend: Exporter helper exposing ``check_env()`` and
            ``run_and_validate()``.
        pool_h: Height component of the pooled output size (presumably in
            output bins; see the review note in ``wrapped_function``).
        pool_w: Width component of the pooled output size.
        spatial_scale: Forwarded positionally to ``roi_align``.
        sampling_ratio: Forwarded positionally to ``roi_align``.
        inputs: Optional sequence whose first item is the feature map and
            whose second item is the ROIs; both are converted to float32
            tensors. When falsy, a random ``1x1x16x16`` feature map and the
            single ROI ``[0, 0, 0, 4, 4]`` are used instead.
        work_dir: Forwarded unchanged to ``backend.run_and_validate``.
    """
    backend.check_env()

    # TODO: check if mmcv-full is installed
    from mmcv.ops import roi_align

    def wrapped_function(torch_input, torch_rois):
        # NOTE(review): the output size is passed as (pool_w, pool_h). mmcv
        # appears to expect (h, w); this is harmless for the square cases
        # parametrized above, but confirm before adding non-square ones.
        return roi_align(torch_input, torch_rois, (pool_w, pool_h),
                         spatial_scale, sampling_ratio, 'avg', True)

    wrapped_model = WrapFunction(wrapped_function)

    if not inputs:
        feature_map = torch.rand(1, 1, 16, 16, dtype=torch.float32)
        single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
    else:
        feature_map = torch.tensor(inputs[0], dtype=torch.float32)
        # The ROIs come from the caller-supplied ``inputs``; indexing the
        # freshly built feature-map tensor here would ignore them.
        single_roi = torch.tensor(inputs[1], dtype=torch.float32)

    backend.run_and_validate(
        wrapped_model, [feature_map, single_roi],
        'roi_align',
        input_names=['input', 'rois'],
        output_names=['roi_feat'],
        work_dir=work_dir)