mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[CustomOps] TensorRT Gather Topk Ops (#1033)
* add gather topk * add shape inference and document * fix faster rcnn * reshape topk * fix
This commit is contained in:
parent
50bd6b1703
commit
0caeaf238c
150
csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp
Normal file
150
csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp
Normal file
@ -0,0 +1,150 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include "gather_topk.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "NvInferVersion.h"
|
||||
#include "gather_topk_kernel.hpp"
|
||||
#include "trt_serialize.hpp"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace {
|
||||
static const char *PLUGIN_VERSION{"1"};
|
||||
static const char *PLUGIN_NAME{"GatherTopk"};
|
||||
} // namespace
|
||||
|
||||
GatherTopk::GatherTopk(const std::string &name) : TRTPluginBase(name) {}
|
||||
|
||||
GatherTopk::GatherTopk(const std::string name, const void *data, size_t length)
|
||||
: TRTPluginBase(name) {}
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt *GatherTopk::clone() const TRT_NOEXCEPT {
|
||||
GatherTopk *plugin = new GatherTopk(mLayerName);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::DimsExprs GatherTopk::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
|
||||
assert(inputs[0].nbDims >= inputs[1].nbDims);
|
||||
nvinfer1::DimsExprs ret;
|
||||
ret.nbDims = inputs[0].nbDims;
|
||||
for (int i = 0; i < inputs[1].nbDims; ++i) {
|
||||
ret.d[i] = inputs[1].d[i];
|
||||
}
|
||||
for (int i = inputs[1].nbDims; i < inputs[0].nbDims; ++i) {
|
||||
ret.d[i] = inputs[0].d[i];
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool GatherTopk::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
|
||||
int nbInputs, int nbOutputs) TRT_NOEXCEPT {
|
||||
switch (pos) {
|
||||
case 0:
|
||||
// data
|
||||
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) ||
|
||||
(ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||
case 1:
|
||||
// indices
|
||||
return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
case 2:
|
||||
// output
|
||||
return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void GatherTopk::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs,
|
||||
int nbOutputs) TRT_NOEXCEPT {}
|
||||
|
||||
size_t GatherTopk::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const TRT_NOEXCEPT {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int GatherTopk::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
|
||||
void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
|
||||
const int *dims = &(inputDesc[0].dims.d[0]);
|
||||
const int *indices_dims = &(inputDesc[1].dims.d[0]);
|
||||
int nbDims = inputDesc[0].dims.nbDims;
|
||||
int indice_nbDims = inputDesc[1].dims.nbDims;
|
||||
|
||||
const void *data = inputs[0];
|
||||
const void *indices = inputs[1];
|
||||
void *output = outputs[0];
|
||||
|
||||
auto data_type = inputDesc[0].type;
|
||||
|
||||
switch (data_type) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
gather_topk_impl<float>((float *)data, (int *)indices, dims, nbDims, indices_dims,
|
||||
indice_nbDims, (float *)output, stream);
|
||||
break;
|
||||
|
||||
case nvinfer1::DataType::kINT32:
|
||||
gather_topk_impl<int>((int *)data, (int *)indices, dims, nbDims, indices_dims, indice_nbDims,
|
||||
(int *)output, stream);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
nvinfer1::DataType GatherTopk::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT {
|
||||
return inputTypes[0];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *GatherTopk::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }
|
||||
|
||||
const char *GatherTopk::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
|
||||
|
||||
int GatherTopk::getNbOutputs() const TRT_NOEXCEPT { return 1; }
|
||||
|
||||
size_t GatherTopk::getSerializationSize() const TRT_NOEXCEPT { return 0; }
|
||||
|
||||
void GatherTopk::serialize(void *buffer) const TRT_NOEXCEPT {}
|
||||
|
||||
GatherTopkCreator::GatherTopkCreator() {
|
||||
mPluginAttributes.clear();
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char *GatherTopkCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }
|
||||
|
||||
const char *GatherTopkCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
|
||||
|
||||
nvinfer1::IPluginV2 *GatherTopkCreator::createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
|
||||
auto *plugin = new GatherTopk(name);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *GatherTopkCreator::deserializePlugin(const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT {
|
||||
auto plugin = new GatherTopk(name, serialData, serialLength);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
REGISTER_TENSORRT_PLUGIN(GatherTopkCreator);
|
||||
} // namespace mmdeploy
|
@ -0,0 +1,64 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#ifndef TRT_SCATTERND_HPP
|
||||
#define TRT_SCATTERND_HPP
|
||||
#include <cublas_v2.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "trt_plugin_base.hpp"
|
||||
|
||||
namespace mmdeploy {
|
||||
class GatherTopk : public TRTPluginBase {
|
||||
public:
|
||||
GatherTopk(const std::string &name);
|
||||
|
||||
GatherTopk(const std::string name, const void *data, size_t length);
|
||||
|
||||
GatherTopk() = delete;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
|
||||
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
|
||||
TRT_NOEXCEPT override;
|
||||
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out,
|
||||
int nbOutputs) TRT_NOEXCEPT override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const TRT_NOEXCEPT override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
|
||||
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const TRT_NOEXCEPT override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const TRT_NOEXCEPT override;
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
int getNbOutputs() const TRT_NOEXCEPT override;
|
||||
size_t getSerializationSize() const TRT_NOEXCEPT override;
|
||||
void serialize(void *buffer) const TRT_NOEXCEPT override;
|
||||
};
|
||||
|
||||
class GatherTopkCreator : public TRTPluginCreatorBase {
|
||||
public:
|
||||
GatherTopkCreator();
|
||||
|
||||
const char *getPluginName() const TRT_NOEXCEPT override;
|
||||
|
||||
const char *getPluginVersion() const TRT_NOEXCEPT override;
|
||||
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
|
||||
TRT_NOEXCEPT override;
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
|
||||
size_t serialLength) TRT_NOEXCEPT override;
|
||||
};
|
||||
} // namespace mmdeploy
|
||||
#endif // TRT_SCATTERND_HPP
|
@ -0,0 +1,46 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
||||
#include "common_cuda_helper.hpp"
|
||||
#include "gather_topk_kernel.hpp"
|
||||
#include "trt_plugin_helper.hpp"
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void gather_topk_kernel(const scalar_t* input, const int* indices, scalar_t* output,
|
||||
int batch, int num_input, int num_indices, int channel) {
|
||||
CUDA_1D_KERNEL_LOOP(index, batch * num_indices * channel) {
|
||||
const int b_id = index / (num_indices * channel);
|
||||
const int n_id = (index / channel) % num_indices;
|
||||
const int c_id = index % channel;
|
||||
|
||||
const int input_n_id = indices[b_id * num_indices + n_id];
|
||||
const scalar_t value = input[b_id * num_input * channel + input_n_id * channel + c_id];
|
||||
output[b_id * num_indices * channel + n_id * channel + c_id] = value;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
|
||||
const int* indices_dims, int indice_nbDims, scalar_t* output,
|
||||
cudaStream_t stream) {
|
||||
int batch = 1;
|
||||
for (int i = 0; i < indice_nbDims - 1; ++i) batch *= dims[i];
|
||||
int num_input = dims[indice_nbDims - 1];
|
||||
int num_indices = indices_dims[indice_nbDims - 1];
|
||||
int channel = 1;
|
||||
for (int i = indice_nbDims; i < nbDims; ++i) channel *= dims[i];
|
||||
const int col_block = DIVUP(batch * num_indices * channel, THREADS_PER_BLOCK);
|
||||
gather_topk_kernel<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(input, indices, output, batch,
|
||||
num_input, num_indices, channel);
|
||||
}
|
||||
|
||||
template void gather_topk_impl<float>(const float* input, const int* indices, const int* dims,
|
||||
int nbDims, const int* indices_dims, int indice_nbDims,
|
||||
float* output, cudaStream_t stream);
|
||||
|
||||
template void gather_topk_impl<int32_t>(const int32_t* input, const int* indices, const int* dims,
|
||||
int nbDims, const int* indices_dims, int indice_nbDims,
|
||||
int32_t* output, cudaStream_t stream);
|
@ -0,0 +1,10 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#ifndef TRT_GRID_SAMPLER_KERNEL_HPP
|
||||
#define TRT_GRID_SAMPLER_KERNEL_HPP
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
template <typename scalar_t>
|
||||
void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims,
|
||||
const int* indices_dims, int indice_nbDims, scalar_t* output,
|
||||
cudaStream_t stream);
|
||||
#endif // TRT_GRID_SAMPLER_KERNEL_HPP
|
@ -63,6 +63,12 @@
|
||||
- [Inputs](#inputs-9)
|
||||
- [Outputs](#outputs-9)
|
||||
- [Type Constraints](#type-constraints-9)
|
||||
- [GatherTopk](#gathertopk)
|
||||
- [Description](#description-10)
|
||||
- [Parameters](#parameters-10)
|
||||
- [Inputs](#inputs-10)
|
||||
- [Outputs](#outputs-10)
|
||||
- [Type Constraints](#type-constraints-10)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
@ -447,3 +453,39 @@ None
|
||||
#### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear)
|
||||
|
||||
### GatherTopk
|
||||
|
||||
#### Description
|
||||
|
||||
TensorRT 8.2~8.4 would give unexpected result for multi-index gather.
|
||||
|
||||
```python
|
||||
data[batch_index, bbox_index, ...]
|
||||
```
|
||||
|
||||
Read [this](https://github.com/NVIDIA/TensorRT/issues/2299) for more details.
|
||||
|
||||
#### Parameters
|
||||
|
||||
None
|
||||
|
||||
#### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>inputs[0]</tt>: T</dt>
|
||||
<dd>Tensor to be gathered, with shape (A0, ..., An, G0, C0, ...).</dd>
|
||||
|
||||
<dt><tt>inputs[1]</tt>: tensor(int32, Linear)</dt>
|
||||
<dd>Tensor of index. with shape (A0, ..., An, G1)</dd>
|
||||
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>outputs[0]</tt>: T</dt>
|
||||
<dd>Tensor of output. With shape (A0, ..., An, G1, C0, ...)</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear), tensor(int32, Linear)
|
||||
|
@ -63,6 +63,12 @@
|
||||
- [Inputs](#inputs-9)
|
||||
- [Outputs](#outputs-9)
|
||||
- [Type Constraints](#type-constraints-9)
|
||||
- [GatherTopk](#gathertopk)
|
||||
- [Description](#description-10)
|
||||
- [Parameters](#parameters-10)
|
||||
- [Inputs](#inputs-10)
|
||||
- [Outputs](#outputs-10)
|
||||
- [Type Constraints](#type-constraints-10)
|
||||
|
||||
<!-- TOC -->
|
||||
|
||||
@ -447,3 +453,39 @@ None
|
||||
#### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear)
|
||||
|
||||
### GatherTopk
|
||||
|
||||
#### Description
|
||||
|
||||
TensorRT 8.2~8.4 would give unexpected result for multi-index gather.
|
||||
|
||||
```python
|
||||
data[batch_index, bbox_index, ...]
|
||||
```
|
||||
|
||||
Read [this](https://github.com/NVIDIA/TensorRT/issues/2299) for more details.
|
||||
|
||||
#### Parameters
|
||||
|
||||
None
|
||||
|
||||
#### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>inputs[0]</tt>: T</dt>
|
||||
<dd>Tensor to be gathered, with shape (A0, ..., An, G0, C0, ...).</dd>
|
||||
|
||||
<dt><tt>inputs[1]</tt>: tensor(int32, Linear)</dt>
|
||||
<dd>Tensor of index. with shape (A0, ..., An, G1)</dd>
|
||||
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>outputs[0]</tt>: T</dt>
|
||||
<dd>Tensor of output. With shape (A0, ..., An, G1, C0, ...)</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
- T:tensor(float32, Linear), tensor(int32, Linear)
|
||||
|
@ -261,6 +261,13 @@ def multiclass_nms_static(ctx,
|
||||
pre_top_k, keep_top_k, iou_threshold,
|
||||
score_threshold, -1)
|
||||
|
||||
# retain shape info
|
||||
batch_size = boxes.size(0)
|
||||
|
||||
dets_shape = dets.shape
|
||||
label_shape = labels.shape
|
||||
dets = dets.reshape([batch_size, *dets_shape[1:]])
|
||||
labels = labels.reshape([batch_size, *label_shape[1:]])
|
||||
return dets, labels
|
||||
|
||||
|
||||
|
@ -200,6 +200,53 @@ def __gather_topk(*inputs: Sequence[torch.Tensor],
|
||||
return outputs
|
||||
|
||||
|
||||
class TRTGatherTopk(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x: torch.Tensor, inds: torch.Tensor):
|
||||
"""Implement of gather topk."""
|
||||
batch_size = x.size(0)
|
||||
batch_inds = torch.arange(batch_size, device=inds.device).unsqueeze(-1)
|
||||
return x[batch_inds, inds, ...]
|
||||
|
||||
@staticmethod
|
||||
def symbolic(g, x, inds):
|
||||
"""symbolic of gather topk."""
|
||||
out = g.op('mmdeploy::GatherTopk', x, inds, outputs=1)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
|
||||
backend=Backend.TENSORRT.value)
|
||||
def __gather_topk__trt(ctx,
|
||||
*inputs: Sequence[torch.Tensor],
|
||||
inds: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_batched: bool = True) -> Tuple[torch.Tensor]:
|
||||
"""TensorRT gather_topk."""
|
||||
_ = ctx
|
||||
if is_batched:
|
||||
index_shape = inds.shape
|
||||
index_dim = inds.dim()
|
||||
outputs = [None for _ in inputs]
|
||||
for i, x in enumerate(inputs):
|
||||
if x is None:
|
||||
continue
|
||||
out = TRTGatherTopk.apply(x, inds).to(x.dtype)
|
||||
out_shape = [*index_shape, *x.shape[index_dim:]]
|
||||
out = out.reshape(out_shape)
|
||||
outputs[i] = out
|
||||
else:
|
||||
prior_inds = inds.new_zeros((1, 1))
|
||||
outputs = [
|
||||
x[prior_inds, inds, ...] if x is not None else None for x in inputs
|
||||
]
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@FUNCTION_REWRITER.register_rewriter(
|
||||
'mmdeploy.codebase.mmdet.deploy.utils.__gather_topk',
|
||||
backend=Backend.COREML.value)
|
||||
|
@ -2,7 +2,7 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmdeploy.codebase.mmdet import (get_post_processing_params,
|
||||
from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
|
||||
multiclass_nms, pad_with_value)
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
|
||||
@ -129,14 +129,18 @@ def gfl_head__get_bbox(ctx,
|
||||
else:
|
||||
max_scores, _ = nms_pre_score[..., :-1].max(-1)
|
||||
_, topk_inds = max_scores.topk(pre_topk)
|
||||
batch_inds = torch.arange(
|
||||
batch_size, device=bbox_pred.device).unsqueeze(-1)
|
||||
prior_inds = batch_inds.new_zeros((1, 1))
|
||||
priors = priors[prior_inds, topk_inds, :]
|
||||
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
|
||||
scores = scores[batch_inds, topk_inds, :]
|
||||
if with_score_factors:
|
||||
score_factors = score_factors[batch_inds, topk_inds, :]
|
||||
bbox_pred, scores, score_factors = gather_topk(
|
||||
bbox_pred,
|
||||
scores,
|
||||
score_factors,
|
||||
inds=topk_inds,
|
||||
batch_size=batch_size,
|
||||
is_batched=True)
|
||||
priors = gather_topk(
|
||||
priors,
|
||||
inds=topk_inds,
|
||||
batch_size=batch_size,
|
||||
is_batched=False)
|
||||
|
||||
mlvl_valid_bboxes.append(bbox_pred)
|
||||
mlvl_valid_scores.append(scores)
|
||||
|
@ -3,7 +3,7 @@ from typing import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from mmdeploy.codebase.mmdet import (get_post_processing_params,
|
||||
from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params,
|
||||
multiclass_nms,
|
||||
pad_with_value_if_necessary)
|
||||
from mmdeploy.core import FUNCTION_REWRITER
|
||||
@ -37,11 +37,14 @@ def reppoints_head__points2bbox(ctx, self, pts, y_first=True):
|
||||
Use `self.moment_transfer` in `points2bbox` will cause error:
|
||||
RuntimeError: Input, output and indices must be on the current device
|
||||
"""
|
||||
moment_transfer = self.moment_transfer
|
||||
delattr(self, 'moment_transfer')
|
||||
self.moment_transfer = torch.tensor(moment_transfer.data)
|
||||
update_moment = hasattr(self, 'moment_transfer')
|
||||
if update_moment:
|
||||
moment_transfer = self.moment_transfer
|
||||
delattr(self, 'moment_transfer')
|
||||
self.moment_transfer = torch.tensor(moment_transfer.data)
|
||||
ret = ctx.origin_func(self, pts, y_first=y_first)
|
||||
self.moment_transfer = moment_transfer
|
||||
if update_moment:
|
||||
self.moment_transfer = moment_transfer
|
||||
return ret
|
||||
|
||||
|
||||
@ -131,12 +134,17 @@ def reppoints_head__get_bboxes(ctx,
|
||||
else:
|
||||
max_scores, _ = nms_pre_score[..., :-1].max(-1)
|
||||
_, topk_inds = max_scores.topk(pre_topk)
|
||||
batch_inds = torch.arange(
|
||||
batch_size, device=bbox_pred.device).unsqueeze(-1)
|
||||
prior_inds = batch_inds.new_zeros((1, 1))
|
||||
priors = priors[prior_inds, topk_inds, :]
|
||||
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
|
||||
scores = scores[batch_inds, topk_inds, :]
|
||||
bbox_pred, scores = gather_topk(
|
||||
bbox_pred,
|
||||
scores,
|
||||
inds=topk_inds,
|
||||
batch_size=batch_size,
|
||||
is_batched=True)
|
||||
priors = gather_topk(
|
||||
priors,
|
||||
inds=topk_inds,
|
||||
batch_size=batch_size,
|
||||
is_batched=False)
|
||||
|
||||
bbox_pred = _bbox_pre_decode(priors, bbox_pred,
|
||||
self.point_strides[level_idx])
|
||||
|
@ -1124,3 +1124,39 @@ def test_dot_product_attention(backend, save_dir=None):
|
||||
input_names=['query', 'key', 'value'],
|
||||
output_names=['out', 'attn'],
|
||||
save_dir=save_dir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
|
||||
def test_gather_topk(backend, save_dir=None):
|
||||
backend.check_env()
|
||||
from mmdeploy.codebase.mmdet.deploy.utils import gather_topk
|
||||
|
||||
x = torch.rand(2, 10, 4).cuda()
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
batch_size = x.size(0)
|
||||
max_x, _ = x.max(-1)
|
||||
_, inds = max_x.topk(4)
|
||||
|
||||
new_x = gather_topk(x, inds=inds, batch_size=batch_size)
|
||||
return new_x
|
||||
|
||||
model = TestModel().cuda()
|
||||
|
||||
with RewriterContext(
|
||||
Config({'backend_config': {
|
||||
'type': backend.backend_name
|
||||
}}),
|
||||
backend=backend.backend_name,
|
||||
opset=11):
|
||||
backend.run_and_validate(
|
||||
model, [x],
|
||||
'gather_topk',
|
||||
input_names=['x'],
|
||||
output_names=['out'],
|
||||
save_dir=save_dir)
|
||||
|
Loading…
x
Reference in New Issue
Block a user