diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp index a98b782d1..8440bb621 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp @@ -26,6 +26,21 @@ class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt { } const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); } + virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) TRT_NOEXCEPT override {} + + virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const TRT_NOEXCEPT override { + return 0; + } + + virtual void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, + nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override {} + + virtual void detachFromContext() TRT_NOEXCEPT override {} + protected: const std::string mLayerName; std::string mNamespace; @@ -34,10 +49,8 @@ class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt { protected: // To prevent compiler warnings. 
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; - using nvinfer1::IPluginV2DynamicExt::configurePlugin; using nvinfer1::IPluginV2DynamicExt::enqueue; using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; - using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize; using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; using nvinfer1::IPluginV2DynamicExt::supportsFormat; #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp new file mode 100644 index 000000000..1850fbfc1 --- /dev/null +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp @@ -0,0 +1,154 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "trt_grid_priors.hpp" + +#include + +#include + +#include "trt_grid_priors_kernel.hpp" +#include "trt_serialize.hpp" + +using namespace nvinfer1; + +namespace mmdeploy { +namespace { +static const char *PLUGIN_VERSION{"1"}; +static const char *PLUGIN_NAME{"GridPriorsTRT"}; +} // namespace + +GridPriorsTRT::GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride) + : TRTPluginBase(name), mStride(stride) {} + +GridPriorsTRT::GridPriorsTRT(const std::string name, const void *data, size_t length) + : TRTPluginBase(name) { + deserialize_value(&data, &length, &mStride); +} +GridPriorsTRT::~GridPriorsTRT() {} + +nvinfer1::IPluginV2DynamicExt *GridPriorsTRT::clone() const TRT_NOEXCEPT { + GridPriorsTRT *plugin = new GridPriorsTRT(mLayerName, mStride); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; +} + +nvinfer1::DimsExprs GridPriorsTRT::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { + // input[0] == base_anchor + // input[1] == empty_h + // input[2] == empty_w + + nvinfer1::DimsExprs ret; + ret.nbDims = 2; + auto area = + exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, 
*inputs[2].d[0], *inputs[1].d[0]); + ret.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *area, *(inputs[0].d[0])); + ret.d[1] = exprBuilder.constant(4); + + return ret; +} + +bool GridPriorsTRT::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, + int nbInputs, int nbOutputs) TRT_NOEXCEPT { + if (pos == 0) { + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } else if (pos - nbInputs == 0) { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } else { + return true; + } +} + +int GridPriorsTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, + void *const *outputs, void *workSpace, + cudaStream_t stream) TRT_NOEXCEPT { + int num_base_anchors = inputDesc[0].dims.d[0]; + int feat_h = inputDesc[1].dims.d[0]; + int feat_w = inputDesc[2].dims.d[0]; + + const void *base_anchor = inputs[0]; + void *output = outputs[0]; + + auto data_type = inputDesc[0].type; + switch (data_type) { + case nvinfer1::DataType::kFLOAT: + trt_grid_priors_impl((float *)base_anchor, (float *)output, num_base_anchors, feat_w, + feat_h, mStride.d[0], mStride.d[1], stream); + break; + default: + return 1; + } + + return 0; +} + +nvinfer1::DataType GridPriorsTRT::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, + int nbInputs) const TRT_NOEXCEPT { + return inputTypes[0]; +} + +// IPluginV2 Methods +const char *GridPriorsTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } + +const char *GridPriorsTRT::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } + +int GridPriorsTRT::getNbOutputs() const TRT_NOEXCEPT { return 1; } + +size_t GridPriorsTRT::getSerializationSize() const TRT_NOEXCEPT { return serialized_size(mStride); } + +void GridPriorsTRT::serialize(void *buffer) const TRT_NOEXCEPT { + serialize_value(&buffer, mStride); + ; 
+} + +////////////////////// creator ///////////////////////////// + +GridPriorsTRTCreator::GridPriorsTRTCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_h")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_w")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char *GridPriorsTRTCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } + +const char *GridPriorsTRTCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } + +nvinfer1::IPluginV2 *GridPriorsTRTCreator::createPlugin( + const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { + int stride_w = 1; + int stride_h = 1; + + for (int i = 0; i < fc->nbFields; i++) { + if (fc->fields[i].data == nullptr) { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("stride_w") == 0) { + stride_w = static_cast(fc->fields[i].data)[0]; + } + if (field_name.compare("stride_h") == 0) { + stride_h = static_cast(fc->fields[i].data)[0]; + } + } + nvinfer1::Dims stride{2, {stride_w, stride_h}}; + + GridPriorsTRT *plugin = new GridPriorsTRT(name, stride); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +nvinfer1::IPluginV2 *GridPriorsTRTCreator::deserializePlugin(const char *name, + const void *serialData, + size_t serialLength) TRT_NOEXCEPT { + auto plugin = new GridPriorsTRT(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} +REGISTER_TENSORRT_PLUGIN(GridPriorsTRTCreator); +} // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp new file mode 100644 index 000000000..0036f6258 --- /dev/null +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp @@ -0,0 +1,66 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+#ifndef TRT_GRID_PRIORS_HPP +#define TRT_GRID_PRIORS_HPP +#include + +#include +#include +#include + +#include "trt_plugin_base.hpp" + +namespace mmdeploy { +class GridPriorsTRT : public TRTPluginBase { + public: + GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride); + + GridPriorsTRT(const std::string name, const void *data, size_t length); + + GridPriorsTRT() = delete; + + ~GridPriorsTRT() TRT_NOEXCEPT override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, + int nbInputs, nvinfer1::IExprBuilder &exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, + void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char *getPluginType() const TRT_NOEXCEPT override; + const char *getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void *buffer) const TRT_NOEXCEPT override; + + private: + nvinfer1::Dims mStride; + + cublasHandle_t m_cublas_handle; +}; + +class GridPriorsTRTCreator : public TRTPluginCreatorBase { + public: + GridPriorsTRTCreator(); + + const char *getPluginName() const TRT_NOEXCEPT override; + + const char *getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) + TRT_NOEXCEPT override; + + nvinfer1::IPluginV2 
*deserializePlugin(const char *name, const void *serialData, + size_t serialLength) TRT_NOEXCEPT override; +}; +} // namespace mmdeploy +#endif // TRT_GRID_PRIORS_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu new file mode 100644 index 000000000..72c33d179 --- /dev/null +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu @@ -0,0 +1,43 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include + +#include "common_cuda_helper.hpp" +#include "trt_grid_priors_kernel.hpp" +#include "trt_plugin_helper.hpp" + +template +__global__ void trt_grid_priors_kernel(const scalar_t* base_anchor, scalar_t* output, + int num_base_anchors, int feat_w, int feat_h, int stride_w, + int stride_h) { + // load base anchor into shared memory. + extern __shared__ scalar_t shared_base_anchor[]; + for (int i = threadIdx.x; i < num_base_anchors * 4; i += blockDim.x) { + shared_base_anchor[i] = base_anchor[i]; + } + __syncthreads(); + + CUDA_1D_KERNEL_LOOP(index, num_base_anchors * feat_w * feat_h) { + const int a_offset = (index % num_base_anchors) << 2; + const scalar_t w = scalar_t(((index / num_base_anchors) % feat_w) * stride_w); + const scalar_t h = scalar_t((index / (feat_w * num_base_anchors)) * stride_h); + + auto out_start = output + index * 4; + out_start[0] = shared_base_anchor[a_offset] + w; + out_start[1] = shared_base_anchor[a_offset + 1] + h; + out_start[2] = shared_base_anchor[a_offset + 2] + w; + out_start[3] = shared_base_anchor[a_offset + 3] + h; + } +} + +template +void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors, + int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream) { + trt_grid_priors_kernel<<>>( + base_anchor, output, (int)num_base_anchors, (int)feat_w, (int)feat_h, (int)stride_w, + (int)stride_h); +} + +template void trt_grid_priors_impl(const float* 
base_anchor, float* output, + int num_base_anchors, int feat_w, int feat_h, + int stride_w, int stride_h, cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp new file mode 100644 index 000000000..77cef58c5 --- /dev/null +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp @@ -0,0 +1,10 @@ +// Copyright (c) OpenMMLab. All rights reserved +#ifndef TRT_GRID_PRIORS_KERNEL_HPP +#define TRT_GRID_PRIORS_KERNEL_HPP +#include + +template +void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors, + int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream); + +#endif diff --git a/docs/en/ops/tensorrt.md b/docs/en/ops/tensorrt.md index d1feae59c..35c196940 100644 --- a/docs/en/ops/tensorrt.md +++ b/docs/en/ops/tensorrt.md @@ -51,6 +51,12 @@ - [Inputs](#inputs-7) - [Outputs](#outputs-7) - [Type Constraints](#type-constraints-7) + - [GridPriorsTRT](#gridpriorstrt) + - [Description](#description-8) + - [Parameters](#parameters-8) + - [Inputs](#inputs-8) + - [Outputs](#outputs-8) + - [Type Constraints](#type-constraints-8) @@ -363,3 +369,39 @@ Batched rotated NMS with a fixed number of output bounding boxes. #### Type Constraints - T:tensor(float32, Linear) + +### GridPriorsTRT + +#### Description + +Generate the anchors for object detection task. + +#### Parameters + +| Type | Parameter | Description | +| ----- | ---------- | --------------------------------- | +| `int` | `stride_w` | The stride of the feature width. | +| `int` | `stride_h` | The stride of the feature height. | + +#### Inputs + +
+
inputs[0]: T
+
The base anchors; 2-D tensor with shape [num_base_anchor, 4].
+
inputs[1]: TAny
+
height provider; 1-D tensor with shape [featmap_height]. The data will never be used.
+
inputs[2]: TAny
+
width provider; 1-D tensor with shape [featmap_width]. The data will never be used.
+
+ +#### Outputs + +
+
outputs[0]: T
+
output anchors; 2-D tensor of shape (num_base_anchor*featmap_height*featmap_width, 4).
+
+ +#### Type Constraints + +- T:tensor(float32, Linear) +- TAny: Any diff --git a/mmdeploy/codebase/mmdet/core/__init__.py b/mmdeploy/codebase/mmdet/core/__init__.py index bf32fab5f..ce86610ae 100644 --- a/mmdeploy/codebase/mmdet/core/__init__.py +++ b/mmdeploy/codebase/mmdet/core/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .anchor import * # noqa: F401,F403 from .bbox import * # noqa: F401,F403 from .ops import * # noqa: F401,F403 from .post_processing import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/core/anchor.py b/mmdeploy/codebase/mmdet/core/anchor.py new file mode 100644 index 000000000..4b8166f86 --- /dev/null +++ b/mmdeploy/codebase/mmdet/core/anchor.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Tuple + +import torch +from torch.onnx import symbolic_helper + +from mmdeploy.core import FUNCTION_REWRITER + + +class GridPriorsTRTOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, base_anchors, feat_h, feat_w, stride_h: int, + stride_w: int): + device = base_anchors.device + dtype = base_anchors.dtype + shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w + shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h + + def _meshgrid(x, y, row_major=True): + # use shape instead of len to keep tracing while exporting to onnx + xx = x.repeat(y.shape[0]) + yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1) + if row_major: + return xx, yy + else: + return yy, xx + + shift_xx, shift_yy = _meshgrid(shift_x, shift_y) + shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1) + + all_anchors = base_anchors[None, :, :] + shifts[:, None, :] + all_anchors = all_anchors.view(-1, 4) + # then (0, 1), (0, 2), ... 
+ return all_anchors + + @staticmethod + @symbolic_helper.parse_args('v', 'v', 'v', 'i', 'i') + def symbolic(g, base_anchors, feat_h, feat_w, stride_h: int, + stride_w: int): + # zero_h and zero_w is used to provide shape to GridPriorsTRT + feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0]) + feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0]) + zero_h = g.op( + 'ConstantOfShape', + feat_h, + value_t=torch.tensor([0], dtype=torch.long), + ) + zero_w = g.op( + 'ConstantOfShape', + feat_w, + value_t=torch.tensor([0], dtype=torch.long), + ) + return g.op( + 'mmdeploy::GridPriorsTRT', + base_anchors, + zero_h, + zero_w, + stride_h_i=stride_h, + stride_w_i=stride_w) + + +grid_priors_trt = GridPriorsTRTOp.apply + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.core.anchor.anchor_generator.' + 'AnchorGenerator.single_level_grid_priors', + backend='tensorrt') +def anchorgenerator__single_level_grid_priors__trt( + ctx, + self, + featmap_size: Tuple[int], + level_idx: int, + dtype: torch.dtype = torch.float32, + device: str = 'cuda') -> torch.Tensor: + """This is a rewrite to replace ONNX anchor generator to TensorRT custom + op. + + Args: + ctx : The rewriter context + featmap_size (tuple[int]): Size of the feature maps. + level_idx (int): The index of corresponding feature map level. + dtype (obj:`torch.dtype`): Date type of points.Defaults to + ``torch.float32``. + device (str, optional): The device the tensor will be put on. + Defaults to 'cuda'. + + Returns: + torch.Tensor: Anchors in the overall feature maps. 
+ """ + feat_h, feat_w = featmap_size + if isinstance(feat_h, int) and isinstance(feat_w, int): + return ctx.origin_func(self, featmap_size, level_idx, dtype, + device).data + base_anchors = self.base_anchors[level_idx].to(device).to(dtype) + stride_w, stride_h = self.strides[level_idx] + return grid_priors_trt(base_anchors, feat_h, feat_w, stride_h, stride_w) diff --git a/tests/test_codebase/test_mmdet/test_mmdet_core.py b/tests/test_codebase/test_mmdet/test_mmdet_core.py index 892e7bfdf..05d021cee 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_core.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_core.py @@ -1,10 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. +import tempfile + import mmcv import numpy as np import pytest import torch from mmdeploy.codebase import import_codebase +from mmdeploy.core.rewriters.rewriter_manager import RewriterContext from mmdeploy.utils import Backend, Codebase from mmdeploy.utils.test import (WrapFunction, backend_checker, check_backend, get_onnx_model, get_rewrite_outputs) @@ -223,3 +226,52 @@ def test_multiclass_nms_with_keep_top_k(pre_top_k): 'multiclass_nms returned more values than "keep_top_k"\n' \ f'dets.shape: {dets.shape}\n' \ f'keep_top_k: {keep_top_k}' + + +@backend_checker(Backend.TENSORRT) +def test__anchorgenerator__single_level_grid_priors(): + backend_type = 'tensorrt' + import onnx + from mmdet.core.anchor import AnchorGenerator + + from mmdeploy.apis.onnx import export + from mmdeploy.codebase.mmdet.core import anchor # noqa + + generator = AnchorGenerator( + scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4]) + + def single_level_grid_priors(input): + return generator.single_level_grid_priors(input.shape[2:], 0, + input.dtype, input.device) + + x = torch.rand(1, 3, 4, 4) + wrapped_func = WrapFunction(single_level_grid_priors) + output = wrapped_func(x) + + # test forward + with RewriterContext({}, backend_type): + wrap_output = wrapped_func(x) + torch.testing.assert_allclose(output, wrap_output) 
+ + onnx_prefix = tempfile.NamedTemporaryFile().name + + export( + wrapped_func, + x, + onnx_prefix, + backend=backend_type, + input_names=['input'], + output_names=['output'], + dynamic_axes=dict(input={ + 2: 'h', + 3: 'w' + })) + + onnx_model = onnx.load(onnx_prefix + '.onnx') + + find_trt_grid_priors = False + for n in onnx_model.graph.node: + if n.op_type == 'GridPriorsTRT': + find_trt_grid_priors = True + + assert find_trt_grid_priors diff --git a/tests/test_ops/test_ops.py b/tests/test_ops/test_ops.py index 97ed57329..e6121880d 100644 --- a/tests/test_ops/test_ops.py +++ b/tests/test_ops/test_ops.py @@ -1005,3 +1005,94 @@ def test_multi_level_rotated_roi_align(backend, output_names=['bbox_feats'], expected_result=expected_result, save_dir=save_dir) + + +@pytest.mark.parametrize('backend', [TEST_TENSORRT]) +@pytest.mark.parametrize('strides', [(4, 4)]) +def test_trt_grid_priors(backend, strides, input_list=None, save_dir=None): + backend.check_env() + + if input_list is None: + input = torch.rand(1, 3, 2, 2) + base_anchors = torch.tensor([[-22.6274, -11.3137, 22.6274, 11.3137], + [-16.0000, -16.0000, 16.0000, 16.0000], + [-11.3137, -22.6274, 11.3137, 22.6274]]) + + expected_result = torch.tensor([[-22.6274, -11.3137, 22.6274, 11.3137], + [-16.0000, -16.0000, 16.0000, 16.0000], + [-11.3137, -22.6274, 11.3137, 22.6274], + [-18.6274, -11.3137, 26.6274, 11.3137], + [-12.0000, -16.0000, 20.0000, 16.0000], + [-7.3137, -22.6274, 15.3137, 22.6274], + [-22.6274, -7.3137, 22.6274, 15.3137], + [-16.0000, -12.0000, 16.0000, 20.0000], + [-11.3137, -18.6274, 11.3137, 26.6274], + [-18.6274, -7.3137, 26.6274, 15.3137], + [-12.0000, -12.0000, 20.0000, 20.0000], + [-7.3137, -18.6274, 15.3137, 26.6274]]) + else: + input = input_list[0] + base_anchors = input_list[1] + expected_result = input_list[2] + input_name = ['input'] + output_name = ['output'] + + class GridPriorsTestOps(torch.autograd.Function): + + @staticmethod + def forward(ctx, base_anchor, feat_h, feat_w, stride_h: 
int, + stride_w: int): + a = base_anchor.shape[0] + return base_anchor.new_empty(feat_h * feat_w * a, 4) + + @staticmethod + def symbolic(g, base_anchor, feat_h, feat_w, stride_h: int, + stride_w: int): + from torch.onnx import symbolic_helper + feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0]) + feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0]) + zero_h = g.op( + 'ConstantOfShape', + feat_h, + value_t=torch.tensor([0], dtype=torch.long), + ) + zero_w = g.op( + 'ConstantOfShape', + feat_w, + value_t=torch.tensor([0], dtype=torch.long), + ) + return g.op( + 'mmdeploy::GridPriorsTRT', + base_anchor, + zero_h, + zero_w, + stride_h_i=stride_h, + stride_w_i=stride_w) + + class GridPriorsTestModel(torch.nn.Module): + + def __init__(self, strides, base_anchors=base_anchors) -> None: + super().__init__() + self.strides = strides + self.base_anchors = base_anchors + + def forward(self, x): + base_anchors = self.base_anchors + h, w = x.shape[2:] + strides = self.strides + return GridPriorsTestOps.apply(base_anchors, h, w, strides[0], + strides[1]) + + model = GridPriorsTestModel(strides=strides) + + backend.run_and_validate( + model, [input], + 'trt_grid_priors', + input_names=input_name, + output_names=output_name, + expected_result=expected_result, + dynamic_axes=dict(input={ + 2: 'h', + 3: 'w' + }), + save_dir=save_dir) diff --git a/tests/test_ops/utils.py b/tests/test_ops/utils.py index 588934bbd..52e563a37 100644 --- a/tests/test_ops/utils.py +++ b/tests/test_ops/utils.py @@ -97,6 +97,7 @@ class TestTensorRTExporter: os.makedirs(save_dir, exist_ok=True) onnx_file_path = os.path.join(save_dir, model_name + '.onnx') trt_file_path = os.path.join(save_dir, model_name + '.engine') + input_list = [data.cuda() for data in input_list] if isinstance(model, onnx.onnx_ml_pb2.ModelProto): onnx.save(model, onnx_file_path) else: @@ -152,7 +153,6 @@ class TestTensorRTExporter: from mmdeploy.backend.tensorrt import TRTWrapper trt_model = TRTWrapper(trt_file_path, 
output_names) - input_list = [data.cuda() for data in input_list] trt_outputs = trt_model(dict(zip(input_names, input_list))) trt_outputs = [trt_outputs[i].float().cpu() for i in output_names] assert_allclose(model_outputs, trt_outputs, tolerate_small_mismatch)