Mirror of https://github.com/open-mmlab/mmdeploy.git, synced 2025-01-14 08:09:43 +08:00
[CustomOps] TensorRT Gather Topk Ops (#1033)

* add gather topk
* add shape inference and document
* fix faster rcnn
* reshape topk
* fix
parent: 50bd6b1703
commit: 0caeaf238c
csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp (new file, 150 lines)
@@ -0,0 +1,150 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "gather_topk.hpp"

#include <assert.h>
#include <stdio.h>

#include <chrono>

#include "NvInferVersion.h"
#include "gather_topk_kernel.hpp"
#include "trt_serialize.hpp"

namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"GatherTopk"};
}  // namespace

GatherTopk::GatherTopk(const std::string &name) : TRTPluginBase(name) {}

GatherTopk::GatherTopk(const std::string name, const void *data, size_t length)
    : TRTPluginBase(name) {}

nvinfer1::IPluginV2DynamicExt *GatherTopk::clone() const TRT_NOEXCEPT {
  GatherTopk *plugin = new GatherTopk(mLayerName);
  plugin->setPluginNamespace(getPluginNamespace());

  return plugin;
}

nvinfer1::DimsExprs GatherTopk::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
  assert(inputs[0].nbDims >= inputs[1].nbDims);
  nvinfer1::DimsExprs ret;
  ret.nbDims = inputs[0].nbDims;
  // leading dimensions come from the indices tensor
  for (int i = 0; i < inputs[1].nbDims; ++i) {
    ret.d[i] = inputs[1].d[i];
  }
  // trailing dimensions come from the data tensor
  for (int i = inputs[1].nbDims; i < inputs[0].nbDims; ++i) {
    ret.d[i] = inputs[0].d[i];
  }
  return ret;
}

bool GatherTopk::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
                                           int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  switch (pos) {
    case 0:
      // data
      return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
              ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) ||
             (ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
              ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
    case 1:
      // indices
      return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
             ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
    case 2:
      // output
      return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
    default:
      return true;
  }
  return true;
}

void GatherTopk::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
                                 const nvinfer1::DynamicPluginTensorDesc *outputs,
                                 int nbOutputs) TRT_NOEXCEPT {}

size_t GatherTopk::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                                    const nvinfer1::PluginTensorDesc *outputs,
                                    int nbOutputs) const TRT_NOEXCEPT {
  return 0;
}

int GatherTopk::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                        const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
                        void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
  const int *dims = &(inputDesc[0].dims.d[0]);
  const int *indices_dims = &(inputDesc[1].dims.d[0]);
  int nbDims = inputDesc[0].dims.nbDims;
  int indice_nbDims = inputDesc[1].dims.nbDims;

  const void *data = inputs[0];
  const void *indices = inputs[1];
  void *output = outputs[0];

  auto data_type = inputDesc[0].type;

  switch (data_type) {
    case nvinfer1::DataType::kFLOAT:
      gather_topk_impl<float>((float *)data, (int *)indices, dims, nbDims, indices_dims,
                              indice_nbDims, (float *)output, stream);
      break;

    case nvinfer1::DataType::kINT32:
      gather_topk_impl<int>((int *)data, (int *)indices, dims, nbDims, indices_dims, indice_nbDims,
                            (int *)output, stream);
      break;
    default:
      break;
  }

  return 0;
}

nvinfer1::DataType GatherTopk::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                                 int nbInputs) const TRT_NOEXCEPT {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *GatherTopk::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GatherTopk::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

int GatherTopk::getNbOutputs() const TRT_NOEXCEPT { return 1; }

size_t GatherTopk::getSerializationSize() const TRT_NOEXCEPT { return 0; }

void GatherTopk::serialize(void *buffer) const TRT_NOEXCEPT {}

GatherTopkCreator::GatherTopkCreator() {
  mPluginAttributes.clear();
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

const char *GatherTopkCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GatherTopkCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

nvinfer1::IPluginV2 *GatherTopkCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  auto *plugin = new GatherTopk(name);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::IPluginV2 *GatherTopkCreator::deserializePlugin(const char *name, const void *serialData,
                                                          size_t serialLength) TRT_NOEXCEPT {
  auto plugin = new GatherTopk(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

REGISTER_TENSORRT_PLUGIN(GatherTopkCreator);
}  // namespace mmdeploy
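The shape inference in `GatherTopk::getOutputDimensions` above takes the leading dimensions from the indices tensor and the trailing dimensions from the data tensor. A minimal Python sketch of that rule (the function name here is illustrative, not part of the plugin):

```python
def gather_topk_output_shape(data_shape, indices_shape):
    """Mirror of GatherTopk::getOutputDimensions.

    The first len(indices_shape) output dims come from the indices,
    the remaining dims come from the data tensor.
    """
    assert len(data_shape) >= len(indices_shape)
    return tuple(indices_shape) + tuple(data_shape[len(indices_shape):])

# (A0, An, G0, C0) gathered with (A0, An, G1) -> (A0, An, G1, C0)
assert gather_topk_output_shape((2, 8, 100, 4), (2, 8, 10)) == (2, 8, 10, 4)
```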
@@ -0,0 +1,64 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCATTERND_HPP
#define TRT_SCATTERND_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
class GatherTopk : public TRTPluginBase {
 public:
  GatherTopk(const std::string &name);

  GatherTopk(const std::string name, const void *data, size_t length);

  GatherTopk() = delete;

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nbOutputs) TRT_NOEXCEPT override;
  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nbOutputs) const TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char *getPluginType() const TRT_NOEXCEPT override;
  const char *getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void *buffer) const TRT_NOEXCEPT override;
};

class GatherTopkCreator : public TRTPluginCreatorBase {
 public:
  GatherTopkCreator();

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;
  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
}  // namespace mmdeploy
#endif  // TRT_SCATTERND_HPP
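`REGISTER_TENSORRT_PLUGIN` in the source file makes this creator discoverable through TensorRT's global plugin registry under the `PLUGIN_NAME`/`PLUGIN_VERSION` pair. A hedged sketch of looking it up from Python; the shared-library name is an assumption about how the mmdeploy ops are packaged:

```python
import ctypes

import tensorrt as trt

# Assumed library name; loading it runs REGISTER_TENSORRT_PLUGIN.
ctypes.CDLL('libmmdeploy_tensorrt_ops.so')

registry = trt.get_plugin_registry()
# 'GatherTopk' / '1' match PLUGIN_NAME and PLUGIN_VERSION in the plugin source.
creator = registry.get_plugin_creator('GatherTopk', '1', '')
assert creator is not None
```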
@@ -0,0 +1,46 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <functional>
#include <numeric>
#include <vector>

#include "common_cuda_helper.hpp"
#include "gather_topk_kernel.hpp"
#include "trt_plugin_helper.hpp"

template <typename scalar_t>
__global__ void gather_topk_kernel(const scalar_t* input, const int* indices, scalar_t* output,
                                   int batch, int num_input, int num_indices, int channel) {
  CUDA_1D_KERNEL_LOOP(index, batch * num_indices * channel) {
    const int b_id = index / (num_indices * channel);
    const int n_id = (index / channel) % num_indices;
    const int c_id = index % channel;

    const int input_n_id = indices[b_id * num_indices + n_id];
    const scalar_t value = input[b_id * num_input * channel + input_n_id * channel + c_id];
    output[b_id * num_indices * channel + n_id * channel + c_id] = value;
  }
}

template <typename scalar_t>
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
                      const int* indices_dims, int indice_nbDims, scalar_t* output,
                      cudaStream_t stream) {
  int batch = 1;
  for (int i = 0; i < indice_nbDims - 1; ++i) batch *= dims[i];
  int num_input = dims[indice_nbDims - 1];
  int num_indices = indices_dims[indice_nbDims - 1];
  int channel = 1;
  for (int i = indice_nbDims; i < nbDims; ++i) channel *= dims[i];
  const int col_block = DIVUP(batch * num_indices * channel, THREADS_PER_BLOCK);
  gather_topk_kernel<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(input, indices, output, batch,
                                                                  num_input, num_indices, channel);
}

template void gather_topk_impl<float>(const float* input, const int* indices, const int* dims,
                                      int nbDims, const int* indices_dims, int indice_nbDims,
                                      float* output, cudaStream_t stream);

template void gather_topk_impl<int32_t>(const int32_t* input, const int* indices, const int* dims,
                                        int nbDims, const int* indices_dims, int indice_nbDims,
                                        int32_t* output, cudaStream_t stream);
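The kernel flattens the problem into (batch, num_indices, channel): everything before the gathered axis is `batch`, the gathered axis itself is `num_input`/`num_indices`, and everything after it is `channel`. A NumPy sketch of the same indexing, useful as a mental model but not part of the PR:

```python
import numpy as np

def gather_topk_ref(input, indices):
    """NumPy reference for gather_topk_kernel's flat indexing."""
    batch = int(np.prod(indices.shape[:-1]))     # dims before the gathered axis
    num_input = input.shape[indices.ndim - 1]    # size of the gathered axis
    num_indices = indices.shape[-1]
    channel = int(np.prod(input.shape[indices.ndim:]))  # trailing dims

    flat_in = input.reshape(batch, num_input, channel)
    flat_inds = indices.reshape(batch, num_indices)
    out = np.empty((batch, num_indices, channel), dtype=input.dtype)
    for b in range(batch):
        out[b] = flat_in[b, flat_inds[b]]        # same math as the CUDA loop
    return out.reshape(*indices.shape, *input.shape[indices.ndim:])

x = np.random.rand(2, 10, 4).astype(np.float32)
inds = np.random.randint(0, 10, (2, 5)).astype(np.int32)
assert gather_topk_ref(x, inds).shape == (2, 5, 4)
```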
@@ -0,0 +1,10 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_GRID_SAMPLER_KERNEL_HPP
#define TRT_GRID_SAMPLER_KERNEL_HPP
#include <cuda_runtime.h>

template <typename scalar_t>
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
                      const int* indices_dims, int indice_nbDims, scalar_t* output,
                      cudaStream_t stream);
#endif  // TRT_GRID_SAMPLER_KERNEL_HPP
@@ -63,6 +63,12 @@
     - [Inputs](#inputs-9)
     - [Outputs](#outputs-9)
     - [Type Constraints](#type-constraints-9)
+  - [GatherTopk](#gathertopk)
+    - [Description](#description-10)
+    - [Parameters](#parameters-10)
+    - [Inputs](#inputs-10)
+    - [Outputs](#outputs-10)
+    - [Type Constraints](#type-constraints-10)
 
 <!-- TOC -->
 

@@ -447,3 +453,39 @@ None
 #### Type Constraints
 
 - T:tensor(float32, Linear)
+
+### GatherTopk
+
+#### Description
+
+TensorRT 8.2~8.4 can give unexpected results for a multi-index gather such as:
+
+```python
+data[batch_index, bbox_index, ...]
+```
+
+Read [this issue](https://github.com/NVIDIA/TensorRT/issues/2299) for more details.
+
+#### Parameters
+
+None
+
+#### Inputs
+
+<dl>
+<dt><tt>inputs[0]</tt>: T</dt>
+<dd>Tensor to be gathered, with shape (A0, ..., An, G0, C0, ...).</dd>
+
+<dt><tt>inputs[1]</tt>: tensor(int32, Linear)</dt>
+<dd>Tensor of indices, with shape (A0, ..., An, G1).</dd>
+</dl>
+
+#### Outputs
+
+<dl>
+<dt><tt>outputs[0]</tt>: T</dt>
+<dd>Output tensor, with shape (A0, ..., An, G1, C0, ...).</dd>
+</dl>
+
+#### Type Constraints
+
+- T:tensor(float32, Linear), tensor(int32, Linear)
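For a concrete instance of the documented shapes (a hedged illustration, not part of the docs themselves):

```python
import torch

data = torch.rand(2, 100, 4)                 # (A0, G0, C0)
inds = torch.randint(0, 100, (2, 10))        # (A0, G1)

# What GatherTopk computes, expressed with plain PyTorch indexing:
batch_inds = torch.arange(2).unsqueeze(-1)   # (A0, 1), broadcasts against inds
out = data[batch_inds, inds, ...]            # (A0, G1, C0) == (2, 10, 4)
assert out.shape == (2, 10, 4)
```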
(The same GatherTopk section is added to a second documentation file; the content is identical to the one above.)
@@ -261,6 +261,13 @@ def multiclass_nms_static(ctx,
                                           pre_top_k, keep_top_k, iou_threshold,
                                           score_threshold, -1)
+
+    # retain shape info
+    batch_size = boxes.size(0)
+    dets_shape = dets.shape
+    label_shape = labels.shape
+    dets = dets.reshape([batch_size, *dets_shape[1:]])
+    labels = labels.reshape([batch_size, *label_shape[1:]])
     return dets, labels
 
 
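The `# retain shape info` block rebuilds `dets` and `labels` with an explicit `batch_size`, presumably so the batch dimension stays tied to the input's shape in the exported graph instead of becoming an opaque dynamic value after the custom NMS op. A toy sketch of the reshape; the shapes here are assumptions:

```python
import torch

boxes = torch.rand(2, 1000, 4)           # assumed (batch, num_boxes, 4)
dets = torch.rand(2, 100, 5)             # assumed NMS output (batch, keep_top_k, 5)
labels = torch.randint(0, 80, (2, 100))

batch_size = boxes.size(0)
# Re-anchor the batch dimension of the outputs to the input's batch size.
dets = dets.reshape([batch_size, *dets.shape[1:]])
labels = labels.reshape([batch_size, *labels.shape[1:]])
```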
@@ -200,6 +200,53 @@ def __gather_topk(*inputs: Sequence[torch.Tensor],
     return outputs
 
 
+class TRTGatherTopk(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, inds: torch.Tensor):
+        """Implementation of gather topk."""
+        batch_size = x.size(0)
+        batch_inds = torch.arange(batch_size, device=inds.device).unsqueeze(-1)
+        return x[batch_inds, inds, ...]
+
+    @staticmethod
+    def symbolic(g, x, inds):
+        """Symbolic of gather topk."""
+        out = g.op('mmdeploy::GatherTopk', x, inds, outputs=1)
+
+        return out
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
+    backend=Backend.TENSORRT.value)
+def __gather_topk__trt(ctx,
+                       *inputs: Sequence[torch.Tensor],
+                       inds: torch.Tensor,
+                       batch_size: int,
+                       is_batched: bool = True) -> Tuple[torch.Tensor]:
+    """TensorRT gather_topk."""
+    _ = ctx
+    if is_batched:
+        index_shape = inds.shape
+        index_dim = inds.dim()
+        outputs = [None for _ in inputs]
+        for i, x in enumerate(inputs):
+            if x is None:
+                continue
+            out = TRTGatherTopk.apply(x, inds).to(x.dtype)
+            out_shape = [*index_shape, *x.shape[index_dim:]]
+            out = out.reshape(out_shape)
+            outputs[i] = out
+    else:
+        prior_inds = inds.new_zeros((1, 1))
+        outputs = [
+            x[prior_inds, inds, ...] if x is not None else None for x in inputs
+        ]
+
+    return outputs
+
+
 @FUNCTION_REWRITER.register_rewriter(
     'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
     backend=Backend.COREML.value)
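In eager mode, `TRTGatherTopk.apply` performs the batched advanced indexing written in `forward`; at export time `symbolic` emits the `mmdeploy::GatherTopk` custom op instead, which the TensorRT plugin above implements. A small eager-mode check of the batched branch, with illustrative values:

```python
import torch

x = torch.rand(2, 10, 4)
inds = torch.randint(0, 10, (2, 5))

batch_inds = torch.arange(2, device=inds.device).unsqueeze(-1)
out = x[batch_inds, inds, ...]                    # what forward() returns

out_shape = [*inds.shape, *x.shape[inds.dim():]]  # the reshape in the rewriter
assert out.reshape(out_shape).shape == (2, 5, 4)
```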
@@ -2,7 +2,7 @@
 import torch
 import torch.nn.functional as F
 
-from mmdeploy.codebase.mmdet import (get_post_processing_params,
+from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
                                      multiclass_nms, pad_with_value)
 from mmdeploy.core import FUNCTION_REWRITER
 from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
@@ -129,14 +129,18 @@ def gfl_head__get_bbox(ctx,
         else:
             max_scores, _ = nms_pre_score[..., :-1].max(-1)
             _, topk_inds = max_scores.topk(pre_topk)
-            batch_inds = torch.arange(
-                batch_size, device=bbox_pred.device).unsqueeze(-1)
-            prior_inds = batch_inds.new_zeros((1, 1))
-            priors = priors[prior_inds, topk_inds, :]
-            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
-            scores = scores[batch_inds, topk_inds, :]
-            if with_score_factors:
-                score_factors = score_factors[batch_inds, topk_inds, :]
+            bbox_pred, scores, score_factors = gather_topk(
+                bbox_pred,
+                scores,
+                score_factors,
+                inds=topk_inds,
+                batch_size=batch_size,
+                is_batched=True)
+            priors = gather_topk(
+                priors,
+                inds=topk_inds,
+                batch_size=batch_size,
+                is_batched=False)
 
         mlvl_valid_bboxes.append(bbox_pred)
         mlvl_valid_scores.append(scores)
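The priors are shared across the batch, which is why they go through `gather_topk` with `is_batched=False`: a zero batch index broadcasts the single prior set against the per-sample `topk_inds`. A sketch of that broadcast with assumed shapes:

```python
import torch

priors = torch.rand(1, 100, 4)              # one prior set shared by the batch
topk_inds = torch.randint(0, 100, (2, 10))  # per-sample top-k indices

prior_inds = topk_inds.new_zeros((1, 1))    # always index batch 0
out = priors[prior_inds, topk_inds, :]      # broadcasts to (2, 10, 4)
assert out.shape == (2, 10, 4)
```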
@@ -3,7 +3,7 @@ from typing import Sequence
 
 import torch
 
-from mmdeploy.codebase.mmdet import (get_post_processing_params,
+from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
                                      multiclass_nms,
                                      pad_with_value_if_necessary)
 from mmdeploy.core import FUNCTION_REWRITER
@@ -37,11 +37,14 @@ def reppoints_head__points2bbox(ctx, self, pts, y_first=True):
     Use `self.moment_transfer` in `points2bbox` will cause error:
     RuntimeError: Input, output and indices must be on the current device
     """
-    moment_transfer = self.moment_transfer
-    delattr(self, 'moment_transfer')
-    self.moment_transfer = torch.tensor(moment_transfer.data)
+    update_moment = hasattr(self, 'moment_transfer')
+    if update_moment:
+        moment_transfer = self.moment_transfer
+        delattr(self, 'moment_transfer')
+        self.moment_transfer = torch.tensor(moment_transfer.data)
     ret = ctx.origin_func(self, pts, y_first=y_first)
-    self.moment_transfer = moment_transfer
+    if update_moment:
+        self.moment_transfer = moment_transfer
     return ret
 
 
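The rewrite above guards the parameter juggling with `hasattr`, so heads lacking `moment_transfer` pass through unchanged. The underlying trick is to temporarily shadow an `nn.Parameter` with a plain tensor so the traced call sees an ordinary constant on the current device; a standalone sketch with a hypothetical module:

```python
import torch

class Head(torch.nn.Module):  # hypothetical stand-in for RepPointsHead
    def __init__(self):
        super().__init__()
        self.moment_transfer = torch.nn.Parameter(torch.rand(2))

head = Head()
moment_transfer = head.moment_transfer
delattr(head, 'moment_transfer')                           # drop the Parameter binding
head.moment_transfer = torch.tensor(moment_transfer.data)  # plain tensor attribute
# ... call the original points2bbox here ...
head.moment_transfer = moment_transfer                     # restore the Parameter
```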
@@ -131,12 +134,17 @@ def reppoints_head__get_bboxes(ctx,
         else:
             max_scores, _ = nms_pre_score[..., :-1].max(-1)
             _, topk_inds = max_scores.topk(pre_topk)
-            batch_inds = torch.arange(
-                batch_size, device=bbox_pred.device).unsqueeze(-1)
-            prior_inds = batch_inds.new_zeros((1, 1))
-            priors = priors[prior_inds, topk_inds, :]
-            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
-            scores = scores[batch_inds, topk_inds, :]
+            bbox_pred, scores = gather_topk(
+                bbox_pred,
+                scores,
+                inds=topk_inds,
+                batch_size=batch_size,
+                is_batched=True)
+            priors = gather_topk(
+                priors,
+                inds=topk_inds,
+                batch_size=batch_size,
+                is_batched=False)
 
         bbox_pred = _bbox_pre_decode(priors, bbox_pred,
                                      self.point_strides[level_idx])
@@ -1124,3 +1124,39 @@ def test_dot_product_attention(backend, save_dir=None):
         input_names=['query', 'key', 'value'],
         output_names=['out', 'attn'],
         save_dir=save_dir)
+
+
+@pytest.mark.parametrize('backend', [TEST_TENSORRT])
+def test_gather_topk(backend, save_dir=None):
+    backend.check_env()
+    from mmdeploy.codebase.mmdet.deploy.utils import gather_topk
+
+    x = torch.rand(2, 10, 4).cuda()
+
+    class TestModel(torch.nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, x):
+            batch_size = x.size(0)
+            max_x, _ = x.max(-1)
+            _, inds = max_x.topk(4)
+
+            new_x = gather_topk(x, inds=inds, batch_size=batch_size)
+            return new_x
+
+    model = TestModel().cuda()
+
+    with RewriterContext(
+            Config({'backend_config': {
+                'type': backend.backend_name
+            }}),
+            backend=backend.backend_name,
+            opset=11):
+        backend.run_and_validate(
+            model, [x],
+            'gather_topk',
+            input_names=['x'],
+            output_names=['out'],
+            save_dir=save_dir)
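The test needs a working TensorRT environment (`backend.check_env()` skips otherwise), but the gather itself can be sanity-checked eagerly. A hedged standalone equivalent of what the engine output is validated against:

```python
import torch

x = torch.rand(2, 10, 4)
max_x, _ = x.max(-1)
_, inds = max_x.topk(4)

# Reference result the TensorRT engine must reproduce.
batch_inds = torch.arange(x.size(0)).unsqueeze(-1)
ref = x[batch_inds, inds, ...]
assert ref.shape == (2, 4, 4)
```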