[CustomOps] TensorRT Gather Topk Ops (#1033)

* add gather topk

* add shape inference and document

* fix faster rcnn

* reshape topk

* fix
This commit is contained in:
q.yao 2022-09-19 13:48:26 +08:00 committed by GitHub
parent 50bd6b1703
commit 0caeaf238c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 476 additions and 20 deletions

View File

@ -0,0 +1,150 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "gather_topk.hpp"
#include <assert.h>
#include <stdio.h>
#include <chrono>
#include "NvInferVersion.h"
#include "gather_topk_kernel.hpp"
#include "trt_serialize.hpp"
namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"GatherTopk"};
} // namespace
GatherTopk::GatherTopk(const std::string &name) : TRTPluginBase(name) {}
GatherTopk::GatherTopk(const std::string name, const void *data, size_t length)
: TRTPluginBase(name) {}
nvinfer1::IPluginV2DynamicExt *GatherTopk::clone() const TRT_NOEXCEPT {
GatherTopk *plugin = new GatherTopk(mLayerName);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::DimsExprs GatherTopk::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
assert(inputs[0].nbDims >= inputs[1].nbDims);
nvinfer1::DimsExprs ret;
ret.nbDims = inputs[0].nbDims;
for (int i = 0; i < inputs[1].nbDims; ++i) {
ret.d[i] = inputs[1].d[i];
}
for (int i = inputs[1].nbDims; i < inputs[0].nbDims; ++i) {
ret.d[i] = inputs[0].d[i];
}
return ret;
}
bool GatherTopk::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
int nbInputs, int nbOutputs) TRT_NOEXCEPT {
switch (pos) {
case 0:
// data
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) ||
(ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
case 1:
// indices
return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
case 2:
// output
return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
default:
return true;
}
return true;
}
void GatherTopk::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs,
int nbOutputs) TRT_NOEXCEPT {}
size_t GatherTopk::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT {
return 0;
}
int GatherTopk::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
const int *dims = &(inputDesc[0].dims.d[0]);
const int *indices_dims = &(inputDesc[1].dims.d[0]);
int nbDims = inputDesc[0].dims.nbDims;
int indice_nbDims = inputDesc[1].dims.nbDims;
const void *data = inputs[0];
const void *indices = inputs[1];
void *output = outputs[0];
auto data_type = inputDesc[0].type;
switch (data_type) {
case nvinfer1::DataType::kFLOAT:
gather_topk_impl<float>((float *)data, (int *)indices, dims, nbDims, indices_dims,
indice_nbDims, (float *)output, stream);
break;
case nvinfer1::DataType::kINT32:
gather_topk_impl<int>((int *)data, (int *)indices, dims, nbDims, indices_dims, indice_nbDims,
(int *)output, stream);
break;
default:
break;
}
return 0;
}
nvinfer1::DataType GatherTopk::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0];
}
// IPluginV2 Methods
const char *GatherTopk::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }
const char *GatherTopk::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
int GatherTopk::getNbOutputs() const TRT_NOEXCEPT { return 1; }
size_t GatherTopk::getSerializationSize() const TRT_NOEXCEPT { return 0; }
void GatherTopk::serialize(void *buffer) const TRT_NOEXCEPT {}
GatherTopkCreator::GatherTopkCreator() {
mPluginAttributes.clear();
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
const char *GatherTopkCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }
const char *GatherTopkCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
nvinfer1::IPluginV2 *GatherTopkCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
auto *plugin = new GatherTopk(name);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
nvinfer1::IPluginV2 *GatherTopkCreator::deserializePlugin(const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT {
auto plugin = new GatherTopk(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
REGISTER_TENSORRT_PLUGIN(GatherTopkCreator);
} // namespace mmdeploy

View File

@ -0,0 +1,64 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_SCATTERND_HPP
#define TRT_SCATTERND_HPP
#include <cublas_v2.h>
#include <memory>
#include <string>
#include <vector>
#include "trt_plugin_base.hpp"
namespace mmdeploy {
class GatherTopk : public TRTPluginBase {
public:
GatherTopk(const std::string &name);
GatherTopk(const std::string name, const void *data, size_t length);
GatherTopk() = delete;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT override;
// IPluginV2 Methods
const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const TRT_NOEXCEPT override;
};
class GatherTopkCreator : public TRTPluginCreatorBase {
public:
GatherTopkCreator();
const char *getPluginName() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif // TRT_SCATTERND_HPP

View File

@ -0,0 +1,46 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <functional>
#include <numeric>
#include <vector>
#include "common_cuda_helper.hpp"
#include "gather_topk_kernel.hpp"
#include "trt_plugin_helper.hpp"
template <typename scalar_t>
__global__ void gather_topk_kernel(const scalar_t* input, const int* indices, scalar_t* output,
int batch, int num_input, int num_indices, int channel) {
CUDA_1D_KERNEL_LOOP(index, batch * num_indices * channel) {
const int b_id = index / (num_indices * channel);
const int n_id = (index / channel) % num_indices;
const int c_id = index % channel;
const int input_n_id = indices[b_id * num_indices + n_id];
const scalar_t value = input[b_id * num_input * channel + input_n_id * channel + c_id];
output[b_id * num_indices * channel + n_id * channel + c_id] = value;
}
}
template <typename scalar_t>
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
const int* indices_dims, int indice_nbDims, scalar_t* output,
cudaStream_t stream) {
int batch = 1;
for (int i = 0; i < indice_nbDims - 1; ++i) batch *= dims[i];
int num_input = dims[indice_nbDims - 1];
int num_indices = indices_dims[indice_nbDims - 1];
int channel = 1;
for (int i = indice_nbDims; i < nbDims; ++i) channel *= dims[i];
const int col_block = DIVUP(batch * num_indices * channel, THREADS_PER_BLOCK);
gather_topk_kernel<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(input, indices, output, batch,
num_input, num_indices, channel);
}
template void gather_topk_impl<float>(const float* input, const int* indices, const int* dims,
int nbDims, const int* indices_dims, int indice_nbDims,
float* output, cudaStream_t stream);
template void gather_topk_impl<int32_t>(const int32_t* input, const int* indices, const int* dims,
int nbDims, const int* indices_dims, int indice_nbDims,
int32_t* output, cudaStream_t stream);

View File

@ -0,0 +1,10 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_GRID_SAMPLER_KERNEL_HPP
#define TRT_GRID_SAMPLER_KERNEL_HPP
#include <cuda_runtime.h>
template <typename scalar_t>
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
const int* indices_dims, int indice_nbDims, scalar_t* output,
cudaStream_t stream);
#endif // TRT_GRID_SAMPLER_KERNEL_HPP

View File

@ -63,6 +63,12 @@
- [Inputs](#inputs-9)
- [Outputs](#outputs-9)
- [Type Constraints](#type-constraints-9)
- [GatherTopk](#gathertopk)
- [Description](#description-10)
- [Parameters](#parameters-10)
- [Inputs](#inputs-10)
- [Outputs](#outputs-10)
- [Type Constraints](#type-constraints-10)
<!-- TOC -->
@ -447,3 +453,39 @@ None
#### Type Constraints
- T:tensor(float32, Linear)
### GatherTopk
#### Description
TensorRT 8.2~8.4 would give unexpected result for multi-index gather.
```python
data[batch_index, bbox_index, ...]
```
Read [this](https://github.com/NVIDIA/TensorRT/issues/2299) for more details.
#### Parameters
None
#### Inputs
<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>Tensor to be gathered, with shape (A0, ..., An, G0, C0, ...).</dd>
<dt><tt>inputs[1]</tt>: tensor(int32, Linear)</dt>
<dd>Tensor of index. with shape (A0, ..., An, G1)</dd>
#### Outputs
<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Tensor of output. With shape (A0, ..., An, G1, C0, ...)</dd>
</dl>
#### Type Constraints
- T:tensor(float32, Linear), tensor(int32, Linear)

View File

@ -63,6 +63,12 @@
- [Inputs](#inputs-9)
- [Outputs](#outputs-9)
- [Type Constraints](#type-constraints-9)
- [GatherTopk](#gathertopk)
- [Description](#description-10)
- [Parameters](#parameters-10)
- [Inputs](#inputs-10)
- [Outputs](#outputs-10)
- [Type Constraints](#type-constraints-10)
<!-- TOC -->
@ -447,3 +453,39 @@ None
#### Type Constraints
- T:tensor(float32, Linear)
### GatherTopk
#### Description
TensorRT 8.2~8.4 would give unexpected result for multi-index gather.
```python
data[batch_index, bbox_index, ...]
```
Read [this](https://github.com/NVIDIA/TensorRT/issues/2299) for more details.
#### Parameters
None
#### Inputs
<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>Tensor to be gathered, with shape (A0, ..., An, G0, C0, ...).</dd>
<dt><tt>inputs[1]</tt>: tensor(int32, Linear)</dt>
<dd>Tensor of index. with shape (A0, ..., An, G1)</dd>
#### Outputs
<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Tensor of output. With shape (A0, ..., An, G1, C0, ...)</dd>
</dl>
#### Type Constraints
- T:tensor(float32, Linear), tensor(int32, Linear)

View File

@ -261,6 +261,13 @@ def multiclass_nms_static(ctx,
pre_top_k, keep_top_k, iou_threshold,
score_threshold, -1)
# retain shape info
batch_size = boxes.size(0)
dets_shape = dets.shape
label_shape = labels.shape
dets = dets.reshape([batch_size, *dets_shape[1:]])
labels = labels.reshape([batch_size, *label_shape[1:]])
return dets, labels

View File

@ -200,6 +200,53 @@ def __gather_topk(*inputs: Sequence[torch.Tensor],
return outputs
class TRTGatherTopk(torch.autograd.Function):
@staticmethod
def forward(ctx, x: torch.Tensor, inds: torch.Tensor):
"""Implement of gather topk."""
batch_size = x.size(0)
batch_inds = torch.arange(batch_size, device=inds.device).unsqueeze(-1)
return x[batch_inds, inds, ...]
@staticmethod
def symbolic(g, x, inds):
"""symbolic of gather topk."""
out = g.op('mmdeploy::GatherTopk', x, inds, outputs=1)
return out
@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
backend=Backend.TENSORRT.value)
def __gather_topk__trt(ctx,
*inputs: Sequence[torch.Tensor],
inds: torch.Tensor,
batch_size: int,
is_batched: bool = True) -> Tuple[torch.Tensor]:
"""TensorRT gather_topk."""
_ = ctx
if is_batched:
index_shape = inds.shape
index_dim = inds.dim()
outputs = [None for _ in inputs]
for i, x in enumerate(inputs):
if x is None:
continue
out = TRTGatherTopk.apply(x, inds).to(x.dtype)
out_shape = [*index_shape, *x.shape[index_dim:]]
out = out.reshape(out_shape)
outputs[i] = out
else:
prior_inds = inds.new_zeros((1, 1))
outputs = [
x[prior_inds, inds, ...] if x is not None else None for x in inputs
]
return outputs
@FUNCTION_REWRITER.register_rewriter(
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
backend=Backend.COREML.value)

View File

@ -2,7 +2,7 @@
import torch
import torch.nn.functional as F
from mmdeploy.codebase.mmdet import (get_post_processing_params,
from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
multiclass_nms, pad_with_value)
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
@ -129,14 +129,18 @@ def gfl_head__get_bbox(ctx,
else:
max_scores, _ = nms_pre_score[..., :-1].max(-1)
_, topk_inds = max_scores.topk(pre_topk)
batch_inds = torch.arange(
batch_size, device=bbox_pred.device).unsqueeze(-1)
prior_inds = batch_inds.new_zeros((1, 1))
priors = priors[prior_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
if with_score_factors:
score_factors = score_factors[batch_inds, topk_inds, :]
bbox_pred, scores, score_factors = gather_topk(
bbox_pred,
scores,
score_factors,
inds=topk_inds,
batch_size=batch_size,
is_batched=True)
priors = gather_topk(
priors,
inds=topk_inds,
batch_size=batch_size,
is_batched=False)
mlvl_valid_bboxes.append(bbox_pred)
mlvl_valid_scores.append(scores)

View File

@ -3,7 +3,7 @@ from typing import Sequence
import torch
from mmdeploy.codebase.mmdet import (get_post_processing_params,
from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
multiclass_nms,
pad_with_value_if_necessary)
from mmdeploy.core import FUNCTION_REWRITER
@ -37,11 +37,14 @@ def reppoints_head__points2bbox(ctx, self, pts, y_first=True):
Use `self.moment_transfer` in `points2bbox` will cause error:
RuntimeError: Input, output and indices must be on the current device
"""
moment_transfer = self.moment_transfer
delattr(self, 'moment_transfer')
self.moment_transfer = torch.tensor(moment_transfer.data)
update_moment = hasattr(self, 'moment_transfer')
if update_moment:
moment_transfer = self.moment_transfer
delattr(self, 'moment_transfer')
self.moment_transfer = torch.tensor(moment_transfer.data)
ret = ctx.origin_func(self, pts, y_first=y_first)
self.moment_transfer = moment_transfer
if update_moment:
self.moment_transfer = moment_transfer
return ret
@ -131,12 +134,17 @@ def reppoints_head__get_bboxes(ctx,
else:
max_scores, _ = nms_pre_score[..., :-1].max(-1)
_, topk_inds = max_scores.topk(pre_topk)
batch_inds = torch.arange(
batch_size, device=bbox_pred.device).unsqueeze(-1)
prior_inds = batch_inds.new_zeros((1, 1))
priors = priors[prior_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
bbox_pred, scores = gather_topk(
bbox_pred,
scores,
inds=topk_inds,
batch_size=batch_size,
is_batched=True)
priors = gather_topk(
priors,
inds=topk_inds,
batch_size=batch_size,
is_batched=False)
bbox_pred = _bbox_pre_decode(priors, bbox_pred,
self.point_strides[level_idx])

View File

@ -1124,3 +1124,39 @@ def test_dot_product_attention(backend, save_dir=None):
input_names=['query', 'key', 'value'],
output_names=['out', 'attn'],
save_dir=save_dir)
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
def test_gather_topk(backend, save_dir=None):
backend.check_env()
from mmdeploy.codebase.mmdet.deploy.utils import gather_topk
x = torch.rand(2, 10, 4).cuda()
class TestModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
batch_size = x.size(0)
max_x, _ = x.max(-1)
_, inds = max_x.topk(4)
new_x = gather_topk(x, inds=inds, batch_size=batch_size)
return new_x
model = TestModel().cuda()
with RewriterContext(
Config({'backend_config': {
'type': backend.backend_name
}}),
backend=backend.backend_name,
opset=11):
backend.run_and_validate(
model, [x],
'gather_topk',
input_names=['x'],
output_names=['out'],
save_dir=save_dir)