mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Enhancement] TensorRT Anchor generator plugin (#646)
* custom trt anchor generator * add ut * add docstring, update doc
This commit is contained in:
parent
4d9e20960d
commit
dc5f9c3746
@ -26,6 +26,21 @@ class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt {
|
||||
}
|
||||
const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); }
|
||||
|
||||
virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out,
|
||||
int nbOutputs) TRT_NOEXCEPT override {}
|
||||
|
||||
virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const TRT_NOEXCEPT override {
|
||||
return 0;
|
||||
}
|
||||
|
||||
virtual void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
|
||||
nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override {}
|
||||
|
||||
virtual void detachFromContext() TRT_NOEXCEPT override {}
|
||||
|
||||
protected:
|
||||
const std::string mLayerName;
|
||||
std::string mNamespace;
|
||||
@ -34,10 +49,8 @@ class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt {
|
||||
protected:
|
||||
// To prevent compiler warnings.
|
||||
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
|
||||
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
|
||||
using nvinfer1::IPluginV2DynamicExt::enqueue;
|
||||
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
|
||||
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
|
||||
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
|
||||
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
|
||||
#endif
|
||||
|
@ -0,0 +1,154 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include "trt_grid_priors.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "trt_grid_priors_kernel.hpp"
|
||||
#include "trt_serialize.hpp"
|
||||
|
||||
using namespace nvinfer1;
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace {
|
||||
static const char *PLUGIN_VERSION{"1"};
|
||||
static const char *PLUGIN_NAME{"GridPriorsTRT"};
|
||||
} // namespace
|
||||
|
||||
GridPriorsTRT::GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride)
|
||||
: TRTPluginBase(name), mStride(stride) {}
|
||||
|
||||
GridPriorsTRT::GridPriorsTRT(const std::string name, const void *data, size_t length)
|
||||
: TRTPluginBase(name) {
|
||||
deserialize_value(&data, &length, &mStride);
|
||||
}
|
||||
GridPriorsTRT::~GridPriorsTRT() {}
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt *GridPriorsTRT::clone() const TRT_NOEXCEPT {
|
||||
GridPriorsTRT *plugin = new GridPriorsTRT(mLayerName, mStride);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::DimsExprs GridPriorsTRT::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
|
||||
// input[0] == base_anchor
|
||||
// input[1] == empty_h
|
||||
// input[2] == empty_w
|
||||
|
||||
nvinfer1::DimsExprs ret;
|
||||
ret.nbDims = 2;
|
||||
auto area =
|
||||
exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[2].d[0], *inputs[1].d[0]);
|
||||
ret.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *area, *(inputs[0].d[0]));
|
||||
ret.d[1] = exprBuilder.constant(4);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool GridPriorsTRT::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
|
||||
int nbInputs, int nbOutputs) TRT_NOEXCEPT {
|
||||
if (pos == 0) {
|
||||
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||
} else if (pos - nbInputs == 0) {
|
||||
return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
int GridPriorsTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
|
||||
void *const *outputs, void *workSpace,
|
||||
cudaStream_t stream) TRT_NOEXCEPT {
|
||||
int num_base_anchors = inputDesc[0].dims.d[0];
|
||||
int feat_h = inputDesc[1].dims.d[0];
|
||||
int feat_w = inputDesc[2].dims.d[0];
|
||||
|
||||
const void *base_anchor = inputs[0];
|
||||
void *output = outputs[0];
|
||||
|
||||
auto data_type = inputDesc[0].type;
|
||||
switch (data_type) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
trt_grid_priors_impl<float>((float *)base_anchor, (float *)output, num_base_anchors, feat_w,
|
||||
feat_h, mStride.d[0], mStride.d[1], stream);
|
||||
break;
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
nvinfer1::DataType GridPriorsTRT::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT {
|
||||
return inputTypes[0];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *GridPriorsTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }
|
||||
|
||||
const char *GridPriorsTRT::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
|
||||
|
||||
int GridPriorsTRT::getNbOutputs() const TRT_NOEXCEPT { return 1; }
|
||||
|
||||
size_t GridPriorsTRT::getSerializationSize() const TRT_NOEXCEPT { return serialized_size(mStride); }
|
||||
|
||||
void GridPriorsTRT::serialize(void *buffer) const TRT_NOEXCEPT {
|
||||
serialize_value(&buffer, mStride);
|
||||
;
|
||||
}
|
||||
|
||||
////////////////////// creator /////////////////////////////
|
||||
|
||||
GridPriorsTRTCreator::GridPriorsTRTCreator() {
|
||||
mPluginAttributes.clear();
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_h"));
|
||||
mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_w"));
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char *GridPriorsTRTCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }
|
||||
|
||||
const char *GridPriorsTRTCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
|
||||
|
||||
nvinfer1::IPluginV2 *GridPriorsTRTCreator::createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
|
||||
int stride_w = 1;
|
||||
int stride_h = 1;
|
||||
|
||||
for (int i = 0; i < fc->nbFields; i++) {
|
||||
if (fc->fields[i].data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string field_name(fc->fields[i].name);
|
||||
|
||||
if (field_name.compare("stride_w") == 0) {
|
||||
stride_w = static_cast<const int *>(fc->fields[i].data)[0];
|
||||
}
|
||||
if (field_name.compare("stride_h") == 0) {
|
||||
stride_h = static_cast<const int *>(fc->fields[i].data)[0];
|
||||
}
|
||||
}
|
||||
nvinfer1::Dims stride{2, {stride_w, stride_h}};
|
||||
|
||||
GridPriorsTRT *plugin = new GridPriorsTRT(name, stride);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *GridPriorsTRTCreator::deserializePlugin(const char *name,
|
||||
const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT {
|
||||
auto plugin = new GridPriorsTRT(name, serialData, serialLength);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
REGISTER_TENSORRT_PLUGIN(GridPriorsTRTCreator);
|
||||
} // namespace mmdeploy
|
@ -0,0 +1,66 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#ifndef TRT_GRID_PRIORS_HPP
|
||||
#define TRT_GRID_PRIORS_HPP
|
||||
#include <cublas_v2.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "trt_plugin_base.hpp"
|
||||
|
||||
namespace mmdeploy {
|
||||
class GridPriorsTRT : public TRTPluginBase {
|
||||
public:
|
||||
GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride);
|
||||
|
||||
GridPriorsTRT(const std::string name, const void *data, size_t length);
|
||||
|
||||
GridPriorsTRT() = delete;
|
||||
|
||||
~GridPriorsTRT() TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
|
||||
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
|
||||
TRT_NOEXCEPT override;
|
||||
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
|
||||
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const TRT_NOEXCEPT override;
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
int getNbOutputs() const TRT_NOEXCEPT override;
|
||||
size_t getSerializationSize() const TRT_NOEXCEPT override;
|
||||
void serialize(void *buffer) const TRT_NOEXCEPT override;
|
||||
|
||||
private:
|
||||
nvinfer1::Dims mStride;
|
||||
|
||||
cublasHandle_t m_cublas_handle;
|
||||
};
|
||||
|
||||
class GridPriorsTRTCreator : public TRTPluginCreatorBase {
|
||||
public:
|
||||
GridPriorsTRTCreator();
|
||||
|
||||
const char *getPluginName() const TRT_NOEXCEPT override;
|
||||
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
|
||||
TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT override;
|
||||
};
|
||||
} // namespace mmdeploy
|
||||
#endif // TRT_GRID_PRIORS_HPP
|
@ -0,0 +1,43 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "common_cuda_helper.hpp"
|
||||
#include "trt_grid_priors_kernel.hpp"
|
||||
#include "trt_plugin_helper.hpp"
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void trt_grid_priors_kernel(const scalar_t* base_anchor, scalar_t* output,
|
||||
int num_base_anchors, int feat_w, int feat_h, int stride_w,
|
||||
int stride_h) {
|
||||
// load base anchor into shared memory.
|
||||
extern __shared__ scalar_t shared_base_anchor[];
|
||||
for (int i = threadIdx.x; i < num_base_anchors * 4; i += blockDim.x) {
|
||||
shared_base_anchor[i] = base_anchor[i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
CUDA_1D_KERNEL_LOOP(index, num_base_anchors * feat_w * feat_h) {
|
||||
const int a_offset = (index % num_base_anchors) << 2;
|
||||
const scalar_t w = scalar_t(((index / num_base_anchors) % feat_w) * stride_w);
|
||||
const scalar_t h = scalar_t((index / (feat_w * num_base_anchors)) * stride_h);
|
||||
|
||||
auto out_start = output + index * 4;
|
||||
out_start[0] = shared_base_anchor[a_offset] + w;
|
||||
out_start[1] = shared_base_anchor[a_offset + 1] + h;
|
||||
out_start[2] = shared_base_anchor[a_offset + 2] + w;
|
||||
out_start[3] = shared_base_anchor[a_offset + 3] + h;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors,
|
||||
int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream) {
|
||||
trt_grid_priors_kernel<<<GET_BLOCKS(num_base_anchors * feat_w * feat_h), THREADS_PER_BLOCK,
|
||||
DIVUP(num_base_anchors * 4, 32) * 32 * sizeof(scalar_t), stream>>>(
|
||||
base_anchor, output, (int)num_base_anchors, (int)feat_w, (int)feat_h, (int)stride_w,
|
||||
(int)stride_h);
|
||||
}
|
||||
|
||||
template void trt_grid_priors_impl<float>(const float* base_anchor, float* output,
|
||||
int num_base_anchors, int feat_w, int feat_h,
|
||||
int stride_w, int stride_h, cudaStream_t stream);
|
@ -0,0 +1,10 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved
|
||||
#ifndef TRT_GRID_PRIORS_KERNEL_HPP
|
||||
#define TRT_GRID_PRIORS_KERNEL_HPP
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
template <typename scalar_t>
|
||||
void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors,
|
||||
int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream);
|
||||
|
||||
#endif
|
@ -51,6 +51,12 @@
|
||||
- [Inputs](#inputs-7)
|
||||
- [Outputs](#outputs-7)
|
||||
- [Type Constraints](#type-constraints-7)
|
||||
- [GridPriorsTRT](#gridpriorstrt)
|
||||
- [Description](#description-8)
|
||||
- [Parameters](#parameters-8)
|
||||
- [Inputs](#inputs-8)
|
||||
- [Outputs](#outputs-8)
|
||||
- [Type Constraints](#type-constraints-8)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
@ -363,3 +369,39 @@ Batched rotated NMS with a fixed number of output bounding boxes.
|
||||
#### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear)
|
||||
|
||||
### GridPriorsTRT
|
||||
|
||||
#### Description
|
||||
|
||||
Generate the anchors for object detection task.
|
||||
|
||||
#### Parameters
|
||||
|
||||
| Type | Parameter | Description |
|
||||
| ----- | ---------- | --------------------------------- |
|
||||
| `int` | `stride_w` | The stride of the feature width. |
|
||||
| `int` | `stride_h` | The stride of the feature height. |
|
||||
|
||||
#### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>inputs[0]</tt>: T</dt>
|
||||
<dd>The base anchors; 2-D tensor with shape [num_base_anchor, 4].</dd>
|
||||
<dt><tt>inputs[1]</tt>: TAny</dt>
|
||||
<dd>height provider; 1-D tensor with shape [featmap_height]. The data will never been used.</dd>
|
||||
<dt><tt>inputs[2]</tt>: TAny</dt>
|
||||
<dd>width provider; 1-D tensor with shape [featmap_width]. The data will never been used.</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>outputs[0]</tt>: T</dt>
|
||||
<dd>output anchors; 2-D tensor of shape (num_base_anchor*featmap_height*featmap_widht, 4).</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear)
|
||||
- TAny: Any
|
||||
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .anchor import * # noqa: F401,F403
|
||||
from .bbox import * # noqa: F401,F403
|
||||
from .ops import * # noqa: F401,F403
|
||||
from .post_processing import * # noqa: F401,F403
|
||||
|
98
mmdeploy/codebase/mmdet/core/anchor.py
Normal file
98
mmdeploy/codebase/mmdet/core/anchor.py
Normal file
@ -0,0 +1,98 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.onnx import symbolic_helper
|
||||
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
|
||||
|
||||
class GridPriorsTRTOp(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, base_anchors, feat_h, feat_w, stride_h: int,
|
||||
stride_w: int):
|
||||
device = base_anchors.device
|
||||
dtype = base_anchors.dtype
|
||||
shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
|
||||
shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h
|
||||
|
||||
def _meshgrid(x, y, row_major=True):
|
||||
# use shape instead of len to keep tracing while exporting to onnx
|
||||
xx = x.repeat(y.shape[0])
|
||||
yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
|
||||
if row_major:
|
||||
return xx, yy
|
||||
else:
|
||||
return yy, xx
|
||||
|
||||
shift_xx, shift_yy = _meshgrid(shift_x, shift_y)
|
||||
shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
|
||||
|
||||
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
|
||||
all_anchors = all_anchors.view(-1, 4)
|
||||
# then (0, 1), (0, 2), ...
|
||||
return all_anchors
|
||||
|
||||
@staticmethod
|
||||
@symbolic_helper.parse_args('v', 'v', 'v', 'i', 'i')
|
||||
def symbolic(g, base_anchors, feat_h, feat_w, stride_h: int,
|
||||
stride_w: int):
|
||||
# zero_h and zero_w is used to provide shape to GridPriorsTRT
|
||||
feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0])
|
||||
feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0])
|
||||
zero_h = g.op(
|
||||
'ConstantOfShape',
|
||||
feat_h,
|
||||
value_t=torch.tensor([0], dtype=torch.long),
|
||||
)
|
||||
zero_w = g.op(
|
||||
'ConstantOfShape',
|
||||
feat_w,
|
||||
value_t=torch.tensor([0], dtype=torch.long),
|
||||
)
|
||||
return g.op(
|
||||
'mmdeploy::GridPriorsTRT',
|
||||
base_anchors,
|
||||
zero_h,
|
||||
zero_w,
|
||||
stride_h_i=stride_h,
|
||||
stride_w_i=stride_w)
|
||||
|
||||
|
||||
grid_priors_trt = GridPriorsTRTOp.apply
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
func_name='mmdet.core.anchor.anchor_generator.'
|
||||
'AnchorGenerator.single_level_grid_priors',
|
||||
backend='tensorrt')
|
||||
def anchorgenerator__single_level_grid_priors__trt(
|
||||
ctx,
|
||||
self,
|
||||
featmap_size: Tuple[int],
|
||||
level_idx: int,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
device: str = 'cuda') -> torch.Tensor:
|
||||
"""This is a rewrite to replace ONNX anchor generator to TensorRT custom
|
||||
op.
|
||||
|
||||
Args:
|
||||
ctx : The rewriter context
|
||||
featmap_size (tuple[int]): Size of the feature maps.
|
||||
level_idx (int): The index of corresponding feature map level.
|
||||
dtype (obj:`torch.dtype`): Date type of points.Defaults to
|
||||
``torch.float32``.
|
||||
device (str, optional): The device the tensor will be put on.
|
||||
Defaults to 'cuda'.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Anchors in the overall feature maps.
|
||||
"""
|
||||
feat_h, feat_w = featmap_size
|
||||
if isinstance(feat_h, int) and isinstance(feat_w, int):
|
||||
return ctx.origin_func(self, featmap_size, level_idx, dtype,
|
||||
device).data
|
||||
base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
|
||||
stride_w, stride_h = self.strides[level_idx]
|
||||
return grid_priors_trt(base_anchors, feat_h, feat_w, stride_h, stride_w)
|
@ -1,10 +1,13 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import tempfile
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmdeploy.codebase import import_codebase
|
||||
from mmdeploy.core.rewriters.rewriter_manager import RewriterContext
|
||||
from mmdeploy.utils import Backend, Codebase
|
||||
from mmdeploy.utils.test import (WrapFunction, backend_checker, check_backend,
|
||||
get_onnx_model, get_rewrite_outputs)
|
||||
@ -223,3 +226,52 @@ def test_multiclass_nms_with_keep_top_k(pre_top_k):
|
||||
'multiclass_nms returned more values than "keep_top_k"\n' \
|
||||
f'dets.shape: {dets.shape}\n' \
|
||||
f'keep_top_k: {keep_top_k}'
|
||||
|
||||
|
||||
@backend_checker(Backend.TENSORRT)
|
||||
def test__anchorgenerator__single_level_grid_priors():
|
||||
backend_type = 'tensorrt'
|
||||
import onnx
|
||||
from mmdet.core.anchor import AnchorGenerator
|
||||
|
||||
from mmdeploy.apis.onnx import export
|
||||
from mmdeploy.codebase.mmdet.core import anchor # noqa
|
||||
|
||||
generator = AnchorGenerator(
|
||||
scales=[8], ratios=[0.5, 1.0, 2.0], strides=[4])
|
||||
|
||||
def single_level_grid_priors(input):
|
||||
return generator.single_level_grid_priors(input.shape[2:], 0,
|
||||
input.dtype, input.device)
|
||||
|
||||
x = torch.rand(1, 3, 4, 4)
|
||||
wrapped_func = WrapFunction(single_level_grid_priors)
|
||||
output = wrapped_func(x)
|
||||
|
||||
# test forward
|
||||
with RewriterContext({}, backend_type):
|
||||
wrap_output = wrapped_func(x)
|
||||
torch.testing.assert_allclose(output, wrap_output)
|
||||
|
||||
onnx_prefix = tempfile.NamedTemporaryFile().name
|
||||
|
||||
export(
|
||||
wrapped_func,
|
||||
x,
|
||||
onnx_prefix,
|
||||
backend=backend_type,
|
||||
input_names=['input'],
|
||||
output_names=['output'],
|
||||
dynamic_axes=dict(input={
|
||||
2: 'h',
|
||||
3: 'w'
|
||||
}))
|
||||
|
||||
onnx_model = onnx.load(onnx_prefix + '.onnx')
|
||||
|
||||
find_trt_grid_priors = False
|
||||
for n in onnx_model.graph.node:
|
||||
if n.op_type == 'GridPriorsTRT':
|
||||
find_trt_grid_priors = True
|
||||
|
||||
assert find_trt_grid_priors
|
||||
|
@ -1005,3 +1005,94 @@ def test_multi_level_rotated_roi_align(backend,
|
||||
output_names=['bbox_feats'],
|
||||
expected_result=expected_result,
|
||||
save_dir=save_dir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
|
||||
@pytest.mark.parametrize('strides', [(4, 4)])
|
||||
def test_trt_grid_priors(backend, strides, input_list=None, save_dir=None):
|
||||
backend.check_env()
|
||||
|
||||
if input_list is None:
|
||||
input = torch.rand(1, 3, 2, 2)
|
||||
base_anchors = torch.tensor([[-22.6274, -11.3137, 22.6274, 11.3137],
|
||||
[-16.0000, -16.0000, 16.0000, 16.0000],
|
||||
[-11.3137, -22.6274, 11.3137, 22.6274]])
|
||||
|
||||
expected_result = torch.tensor([[-22.6274, -11.3137, 22.6274, 11.3137],
|
||||
[-16.0000, -16.0000, 16.0000, 16.0000],
|
||||
[-11.3137, -22.6274, 11.3137, 22.6274],
|
||||
[-18.6274, -11.3137, 26.6274, 11.3137],
|
||||
[-12.0000, -16.0000, 20.0000, 16.0000],
|
||||
[-7.3137, -22.6274, 15.3137, 22.6274],
|
||||
[-22.6274, -7.3137, 22.6274, 15.3137],
|
||||
[-16.0000, -12.0000, 16.0000, 20.0000],
|
||||
[-11.3137, -18.6274, 11.3137, 26.6274],
|
||||
[-18.6274, -7.3137, 26.6274, 15.3137],
|
||||
[-12.0000, -12.0000, 20.0000, 20.0000],
|
||||
[-7.3137, -18.6274, 15.3137, 26.6274]])
|
||||
else:
|
||||
input = input_list[0]
|
||||
base_anchors = input_list[1]
|
||||
expected_result = input_list[2]
|
||||
input_name = ['input']
|
||||
output_name = ['output']
|
||||
|
||||
class GridPriorsTestOps(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, base_anchor, feat_h, feat_w, stride_h: int,
|
||||
stride_w: int):
|
||||
a = base_anchor.shape[0]
|
||||
return base_anchor.new_empty(feat_h * feat_w * a, 4)
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, base_anchor, feat_h, feat_w, stride_h: int,
|
||||
stride_w: int):
|
||||
from torch.onnx import symbolic_helper
|
||||
feat_h = symbolic_helper._unsqueeze_helper(g, feat_h, [0])
|
||||
feat_w = symbolic_helper._unsqueeze_helper(g, feat_w, [0])
|
||||
zero_h = g.op(
|
||||
'ConstantOfShape',
|
||||
feat_h,
|
||||
value_t=torch.tensor([0], dtype=torch.long),
|
||||
)
|
||||
zero_w = g.op(
|
||||
'ConstantOfShape',
|
||||
feat_w,
|
||||
value_t=torch.tensor([0], dtype=torch.long),
|
||||
)
|
||||
return g.op(
|
||||
'mmdeploy::GridPriorsTRT',
|
||||
base_anchor,
|
||||
zero_h,
|
||||
zero_w,
|
||||
stride_h_i=stride_h,
|
||||
stride_w_i=stride_w)
|
||||
|
||||
class GridPriorsTestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, strides, base_anchors=base_anchors) -> None:
|
||||
super().__init__()
|
||||
self.strides = strides
|
||||
self.base_anchors = base_anchors
|
||||
|
||||
def forward(self, x):
|
||||
base_anchors = self.base_anchors
|
||||
h, w = x.shape[2:]
|
||||
strides = self.strides
|
||||
return GridPriorsTestOps.apply(base_anchors, h, w, strides[0],
|
||||
strides[1])
|
||||
|
||||
model = GridPriorsTestModel(strides=strides)
|
||||
|
||||
backend.run_and_validate(
|
||||
model, [input],
|
||||
'trt_grid_priors',
|
||||
input_names=input_name,
|
||||
output_names=output_name,
|
||||
expected_result=expected_result,
|
||||
dynamic_axes=dict(input={
|
||||
2: 'h',
|
||||
3: 'w'
|
||||
}),
|
||||
save_dir=save_dir)
|
||||
|
@ -97,6 +97,7 @@ class TestTensorRTExporter:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
onnx_file_path = os.path.join(save_dir, model_name + '.onnx')
|
||||
trt_file_path = os.path.join(save_dir, model_name + '.engine')
|
||||
input_list = [data.cuda() for data in input_list]
|
||||
if isinstance(model, onnx.onnx_ml_pb2.ModelProto):
|
||||
onnx.save(model, onnx_file_path)
|
||||
else:
|
||||
@ -152,7 +153,6 @@ class TestTensorRTExporter:
|
||||
|
||||
from mmdeploy.backend.tensorrt import TRTWrapper
|
||||
trt_model = TRTWrapper(trt_file_path, output_names)
|
||||
input_list = [data.cuda() for data in input_list]
|
||||
trt_outputs = trt_model(dict(zip(input_names, input_list)))
|
||||
trt_outputs = [trt_outputs[i].float().cpu() for i in output_names]
|
||||
assert_allclose(model_outputs, trt_outputs, tolerate_small_mismatch)
|
||||
|
Loading…
x
Reference in New Issue
Block a user