mirror of https://github.com/open-mmlab/mmcv.git
[Feature] : Add ScatterND TensorRT Plugin (#786)
* add scatter plugin * fix bugs of scatternd * add trt scatternd plugin * format code with clang-format * add test for scatternd * skip test_tensorrt in CI * remove unused variable Co-authored-by: maningsheng <maningsheng@sensetime.com>pull/796/head
parent
8e3a801596
commit
2ab544fc29
|
@ -1,8 +1,10 @@
|
|||
#include "trt_plugin.hpp"
|
||||
|
||||
#include "trt_roi_align.hpp"
|
||||
#include "trt_scatternd.hpp"
|
||||
|
||||
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
|
||||
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
|
||||
|
||||
extern "C" {
|
||||
bool initLibMMCVInferPlugins() { return true; }
|
||||
|
|
|
@ -0,0 +1,206 @@
|
|||
#include "trt_scatternd.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <chrono>
|
||||
|
||||
#include "trt_serialize.hpp"
|
||||
|
||||
extern void TRTONNXScatterNDKernelLauncher_float(
|
||||
const float *data, const int *indices, const float *update, const int *dims,
|
||||
int nbDims, const int *indices_dims, int indice_nbDims, float *output,
|
||||
cudaStream_t stream);
|
||||
|
||||
extern void TRTONNXScatterNDKernelLauncher_int32(
|
||||
const int *data, const int *indices, const int *update, const int *dims,
|
||||
int nbDims, const int *indices_dims, int indice_nbDims, int *output,
|
||||
cudaStream_t stream);
|
||||
|
||||
namespace {
|
||||
static const char *PLUGIN_VERSION{"1"};
|
||||
static const char *PLUGIN_NAME{"ScatterND"};
|
||||
} // namespace
|
||||
|
||||
nvinfer1::PluginFieldCollection ONNXScatterNDDynamicCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField>
|
||||
ONNXScatterNDDynamicCreator::mPluginAttributes;
|
||||
|
||||
ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string &name)
|
||||
: mLayerName(name) {}
|
||||
|
||||
ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string name,
|
||||
const void *data, size_t length)
|
||||
: mLayerName(name) {}
|
||||
|
||||
nvinfer1::IPluginV2DynamicExt *ONNXScatterNDDynamic::clone() const {
|
||||
ONNXScatterNDDynamic *plugin = new ONNXScatterNDDynamic(mLayerName);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::DimsExprs ONNXScatterNDDynamic::getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) {
|
||||
return inputs[0];
|
||||
}
|
||||
|
||||
bool ONNXScatterNDDynamic::supportsFormatCombination(
|
||||
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
|
||||
int nbOutputs) {
|
||||
if (pos < nbInputs) {
|
||||
switch (pos) {
|
||||
case 0:
|
||||
// data
|
||||
return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR) ||
|
||||
(inOut[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);
|
||||
case 1:
|
||||
// indices
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
|
||||
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
|
||||
case 2:
|
||||
// updates
|
||||
return inOut[pos].type == inOut[0].type &&
|
||||
inOut[pos].format == inOut[0].format;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
} else {
|
||||
switch (pos - nbInputs) {
|
||||
case 0:
|
||||
// output
|
||||
return inOut[pos].type == inOut[0].type &&
|
||||
inOut[pos].format == inOut[0].format;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void ONNXScatterNDDynamic::configurePlugin(
|
||||
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
|
||||
|
||||
size_t ONNXScatterNDDynamic::getWorkspaceSize(
|
||||
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ONNXScatterNDDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs,
|
||||
void *const *outputs, void *workSpace,
|
||||
cudaStream_t stream) {
|
||||
const int *dims = &(inputDesc[0].dims.d[0]);
|
||||
const int *indices_dims = &(inputDesc[1].dims.d[0]);
|
||||
int nbDims = inputDesc[0].dims.nbDims;
|
||||
int indice_nbDims = inputDesc[1].dims.nbDims;
|
||||
|
||||
const void *data = inputs[0];
|
||||
const void *indices = inputs[1];
|
||||
const void *update = inputs[2];
|
||||
void *output = outputs[0];
|
||||
|
||||
auto data_type = inputDesc[0].type;
|
||||
|
||||
switch (data_type) {
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
TRTONNXScatterNDKernelLauncher_float(
|
||||
(float *)data, (int *)indices, (float *)update, dims, nbDims,
|
||||
indices_dims, indice_nbDims, (float *)output, stream);
|
||||
break;
|
||||
|
||||
case nvinfer1::DataType::kINT32:
|
||||
TRTONNXScatterNDKernelLauncher_int32(
|
||||
(int *)data, (int *)indices, (int *)update, dims, nbDims,
|
||||
indices_dims, indice_nbDims, (int *)output, stream);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
nvinfer1::DataType ONNXScatterNDDynamic::getOutputDataType(
|
||||
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
|
||||
return inputTypes[0];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *ONNXScatterNDDynamic::getPluginType() const { return PLUGIN_NAME; }
|
||||
|
||||
const char *ONNXScatterNDDynamic::getPluginVersion() const {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int ONNXScatterNDDynamic::getNbOutputs() const { return 1; }
|
||||
|
||||
int ONNXScatterNDDynamic::initialize() { return 0; }
|
||||
|
||||
void ONNXScatterNDDynamic::terminate() {}
|
||||
|
||||
size_t ONNXScatterNDDynamic::getSerializationSize() const { return 0; }
|
||||
|
||||
void ONNXScatterNDDynamic::serialize(void *buffer) const {}
|
||||
|
||||
void ONNXScatterNDDynamic::destroy() {
|
||||
// This gets called when the network containing plugin is destroyed
|
||||
delete this;
|
||||
}
|
||||
|
||||
void ONNXScatterNDDynamic::setPluginNamespace(const char *libNamespace) {
|
||||
mNamespace = libNamespace;
|
||||
}
|
||||
|
||||
const char *ONNXScatterNDDynamic::getPluginNamespace() const {
|
||||
return mNamespace.c_str();
|
||||
}
|
||||
|
||||
////////////////////// creator /////////////////////////////
|
||||
|
||||
ONNXScatterNDDynamicCreator::ONNXScatterNDDynamicCreator() {
|
||||
mPluginAttributes.clear();
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
const char *ONNXScatterNDDynamicCreator::getPluginName() const {
|
||||
return PLUGIN_NAME;
|
||||
}
|
||||
|
||||
const char *ONNXScatterNDDynamicCreator::getPluginVersion() const {
|
||||
return PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
const nvinfer1::PluginFieldCollection *
|
||||
ONNXScatterNDDynamicCreator::getFieldNames() {
|
||||
return &mFC;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) {
|
||||
ONNXScatterNDDynamic *plugin = new ONNXScatterNDDynamic(name);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::deserializePlugin(
|
||||
const char *name, const void *serialData, size_t serialLength) {
|
||||
auto plugin = new ONNXScatterNDDynamic(name, serialData, serialLength);
|
||||
plugin->setPluginNamespace(getPluginNamespace());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
void ONNXScatterNDDynamicCreator::setPluginNamespace(const char *libNamespace) {
|
||||
mNamespace = libNamespace;
|
||||
}
|
||||
|
||||
const char *ONNXScatterNDDynamicCreator::getPluginNamespace() const {
|
||||
return mNamespace.c_str();
|
||||
}
|
|
@ -0,0 +1,92 @@
|
|||
#include <stdio.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "common_cuda_helper.hpp"
|
||||
#include "trt_cuda_helper.cuh"
|
||||
#include "trt_plugin_helper.hpp"
|
||||
|
||||
static int const threadsPerBlock = sizeof(unsigned long long int) * 8;
|
||||
|
||||
using mmcv::TensorDesc;
|
||||
|
||||
template <typename T>
|
||||
__global__ void onnx_scatternd_kernel(const int n, const int* indices,
|
||||
const T* update, T* output,
|
||||
TensorDesc tensor_desc,
|
||||
TensorDesc indice_desc) {
|
||||
const int indice_cols = indice_desc.shape[indice_desc.dim - 1];
|
||||
const int copy_stride = tensor_desc.stride[indice_cols - 1];
|
||||
const int* stride = &(tensor_desc.stride[0]);
|
||||
CUDA_1D_KERNEL_LOOP(index, n) {
|
||||
int output_offset = 0;
|
||||
const int* indices_current = indices + index * indice_cols;
|
||||
for (int i = 0; i < indice_cols; ++i) {
|
||||
output_offset += stride[i] * indices_current[i];
|
||||
}
|
||||
memcpy(output + output_offset, update + index * copy_stride,
|
||||
copy_stride * sizeof(T));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices,
|
||||
const T* update, const int* dims,
|
||||
int nbDims, const int* indices_dims,
|
||||
int indice_nbDims, T* output,
|
||||
cudaStream_t stream) {
|
||||
// fill tensordesc and initial
|
||||
TensorDesc tensor_desc;
|
||||
memset((void*)&tensor_desc, 0, sizeof(TensorDesc));
|
||||
tensor_desc.dim = nbDims;
|
||||
tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
|
||||
tensor_desc.stride[nbDims - 1] = 1;
|
||||
for (int i = nbDims - 2; i >= 0; --i) {
|
||||
tensor_desc.shape[i] = dims[i];
|
||||
tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
|
||||
}
|
||||
const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0];
|
||||
|
||||
TensorDesc indice_desc;
|
||||
memset((void*)&indice_desc, 0, sizeof(TensorDesc));
|
||||
indice_desc.dim = indice_nbDims;
|
||||
indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1];
|
||||
indice_desc.stride[indice_nbDims - 1] = 1;
|
||||
for (int i = indice_nbDims - 2; i >= 0; --i) {
|
||||
indice_desc.shape[i] = indices_dims[i];
|
||||
indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1];
|
||||
}
|
||||
|
||||
// output = np.copy(data)
|
||||
cudaMemcpyAsync(output, data, data_size * sizeof(T),
|
||||
cudaMemcpyDeviceToDevice);
|
||||
|
||||
int num_update_indice = 1;
|
||||
for (int i = 0; i < indice_nbDims - 1; ++i) {
|
||||
num_update_indice *= indice_desc.shape[i];
|
||||
}
|
||||
// scatter
|
||||
const int col_block = DIVUP(num_update_indice, threadsPerBlock);
|
||||
onnx_scatternd_kernel<<<col_block, threadsPerBlock, 0, stream>>>(
|
||||
num_update_indice, indices, update, output, tensor_desc, indice_desc);
|
||||
}
|
||||
|
||||
void TRTONNXScatterNDKernelLauncher_float(const float* data, const int* indices,
|
||||
const float* update, const int* dims,
|
||||
int nbDims, const int* indices_dims,
|
||||
int indice_nbDims, float* output,
|
||||
cudaStream_t stream) {
|
||||
TRTONNXScatterNDKernelLauncher<float>(data, indices, update, dims, nbDims,
|
||||
indices_dims, indice_nbDims, output,
|
||||
stream);
|
||||
}
|
||||
|
||||
void TRTONNXScatterNDKernelLauncher_int32(const int* data, const int* indices,
|
||||
const int* update, const int* dims,
|
||||
int nbDims, const int* indices_dims,
|
||||
int indice_nbDims, int* output,
|
||||
cudaStream_t stream) {
|
||||
TRTONNXScatterNDKernelLauncher<int>(data, indices, update, dims, nbDims,
|
||||
indices_dims, indice_nbDims, output,
|
||||
stream);
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
#ifndef TRT_CUDA_HELPER_HPP
|
||||
#define TRT_CUDA_HELPER_HPP
|
||||
|
||||
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
|
||||
|
||||
#define cudaCheckError() \
|
||||
{ \
|
||||
cudaError_t e = cudaGetLastError(); \
|
||||
if (e != cudaSuccess) { \
|
||||
printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \
|
||||
cudaGetErrorString(e)); \
|
||||
exit(0); \
|
||||
} \
|
||||
}
|
||||
|
||||
#endif // TRT_CUDA_HELPER_HPP
|
|
@ -6,6 +6,14 @@
|
|||
|
||||
namespace mmcv {
|
||||
|
||||
const int MAXTENSORDIMS = 10;
|
||||
|
||||
struct TensorDesc {
|
||||
int shape[MAXTENSORDIMS];
|
||||
int stride[MAXTENSORDIMS];
|
||||
int dim;
|
||||
};
|
||||
|
||||
inline unsigned int getElementSize(nvinfer1::DataType t) {
|
||||
switch (t) {
|
||||
case nvinfer1::DataType::kINT32:
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
#ifndef TRT_SCATTERND_HPP
|
||||
#define TRT_SCATTERND_HPP
|
||||
#include <cublas_v2.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "trt_plugin_helper.hpp"
|
||||
|
||||
class ONNXScatterNDDynamic : public nvinfer1::IPluginV2DynamicExt {
|
||||
public:
|
||||
ONNXScatterNDDynamic(const std::string &name);
|
||||
|
||||
ONNXScatterNDDynamic(const std::string name, const void *data, size_t length);
|
||||
|
||||
ONNXScatterNDDynamic() = delete;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt *clone() const override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(
|
||||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder &exprBuilder) override;
|
||||
bool supportsFormatCombination(int pos,
|
||||
const nvinfer1::PluginTensorDesc *inOut,
|
||||
int nbInputs, int nbOutputs) override;
|
||||
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
|
||||
int nbInputs,
|
||||
const nvinfer1::DynamicPluginTensorDesc *out,
|
||||
int nbOutputs) override;
|
||||
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
|
||||
int nbInputs,
|
||||
const nvinfer1::PluginTensorDesc *outputs,
|
||||
int nbOutputs) const override;
|
||||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
|
||||
const nvinfer1::PluginTensorDesc *outputDesc,
|
||||
const void *const *inputs, void *const *outputs, void *workspace,
|
||||
cudaStream_t stream) override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(int index,
|
||||
const nvinfer1::DataType *inputTypes,
|
||||
int nbInputs) const override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
const char *getPluginType() const override;
|
||||
const char *getPluginVersion() const override;
|
||||
int getNbOutputs() const override;
|
||||
int initialize() override;
|
||||
void terminate() override;
|
||||
size_t getSerializationSize() const override;
|
||||
void serialize(void *buffer) const override;
|
||||
void destroy() override;
|
||||
void setPluginNamespace(const char *pluginNamespace) override;
|
||||
const char *getPluginNamespace() const override;
|
||||
|
||||
private:
|
||||
const std::string mLayerName;
|
||||
std::string mNamespace;
|
||||
|
||||
protected:
|
||||
// To prevent compiler warnings.
|
||||
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
|
||||
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
|
||||
using nvinfer1::IPluginV2DynamicExt::enqueue;
|
||||
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
|
||||
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
|
||||
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
|
||||
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
|
||||
};
|
||||
|
||||
class ONNXScatterNDDynamicCreator : public nvinfer1::IPluginCreator {
|
||||
public:
|
||||
ONNXScatterNDDynamicCreator();
|
||||
|
||||
const char *getPluginName() const override;
|
||||
|
||||
const char *getPluginVersion() const override;
|
||||
|
||||
const nvinfer1::PluginFieldCollection *getFieldNames() override;
|
||||
|
||||
nvinfer1::IPluginV2 *createPlugin(
|
||||
const char *name, const nvinfer1::PluginFieldCollection *fc) override;
|
||||
|
||||
nvinfer1::IPluginV2 *deserializePlugin(const char *name,
|
||||
const void *serialData,
|
||||
size_t serialLength) override;
|
||||
|
||||
void setPluginNamespace(const char *pluginNamespace) override;
|
||||
|
||||
const char *getPluginNamespace() const override;
|
||||
|
||||
private:
|
||||
static nvinfer1::PluginFieldCollection mFC;
|
||||
static std::vector<nvinfer1::PluginField> mPluginAttributes;
|
||||
std::string mNamespace;
|
||||
};
|
||||
#endif // TRT_SCATTERND_HPP
|
|
@ -5,21 +5,38 @@ import onnx
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
try:
|
||||
from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
|
||||
save_trt_engine)
|
||||
except ImportError:
|
||||
pytest.skip(
|
||||
'TensorRT should be installed from source.', allow_module_level=True)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip(
|
||||
'CUDA is required for this test module', allow_module_level=True)
|
||||
|
||||
if not is_tensorrt_plugin_loaded():
|
||||
pytest.skip(
|
||||
'Test requires to complie TensorRT plugins in mmcv',
|
||||
allow_module_level=True)
|
||||
|
||||
|
||||
class WrapFunction(torch.nn.Module):
|
||||
|
||||
def __init__(self, wrapped_function):
|
||||
super(WrapFunction, self).__init__()
|
||||
self.wrapped_function = wrapped_function
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.wrapped_function(*args, **kwargs)
|
||||
|
||||
|
||||
onnx_file = 'tmp.onnx'
|
||||
trt_file = 'tmp.engine'
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='CUDA is required for test_roialign')
|
||||
def test_roialign():
|
||||
try:
|
||||
from mmcv.tensorrt import (TRTWraper, onnx2trt, save_trt_engine,
|
||||
is_tensorrt_plugin_loaded)
|
||||
if not is_tensorrt_plugin_loaded():
|
||||
pytest.skip('test requires to complie TensorRT plugins in mmcv')
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
pytest.skip('test requires to install TensorRT from source.')
|
||||
|
||||
try:
|
||||
from mmcv.ops import RoIAlign
|
||||
except (ImportError, ModuleNotFoundError):
|
||||
|
@ -91,3 +108,63 @@ def test_roialign():
|
|||
if os.path.exists(trt_file):
|
||||
os.remove(trt_file)
|
||||
assert torch.allclose(pytorch_roi_feat, trt_roi_feat)
|
||||
|
||||
|
||||
def test_scatternd():
|
||||
|
||||
def func(data):
|
||||
data[:, :-2] += 1
|
||||
data[:2, :] -= 1
|
||||
return data
|
||||
|
||||
data = torch.zeros(4, 4).cuda()
|
||||
wrapped_model = WrapFunction(func).eval().cuda()
|
||||
|
||||
input_names = ['input']
|
||||
output_names = ['output']
|
||||
|
||||
with torch.no_grad():
|
||||
torch.onnx.export(
|
||||
wrapped_model, (data.clone(), ),
|
||||
onnx_file,
|
||||
export_params=True,
|
||||
keep_initializers_as_inputs=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
opset_version=11)
|
||||
|
||||
onnx_model = onnx.load(onnx_file)
|
||||
|
||||
# create trt engine and wraper
|
||||
opt_shape_dict = {
|
||||
'input': [list(data.shape),
|
||||
list(data.shape),
|
||||
list(data.shape)],
|
||||
}
|
||||
# trt config
|
||||
fp16_mode = False
|
||||
max_workspace_size = 1 << 30
|
||||
|
||||
trt_engine = onnx2trt(
|
||||
onnx_model,
|
||||
opt_shape_dict,
|
||||
fp16_mode=fp16_mode,
|
||||
max_workspace_size=max_workspace_size)
|
||||
|
||||
save_trt_engine(trt_engine, trt_file)
|
||||
trt_model = TRTWraper(trt_file, input_names, output_names)
|
||||
|
||||
with torch.no_grad():
|
||||
trt_outputs = trt_model({'input': data.clone()})
|
||||
trt_results = trt_outputs['output']
|
||||
|
||||
# compute pytorch_output
|
||||
with torch.no_grad():
|
||||
pytorch_results = wrapped_model(data.clone())
|
||||
|
||||
# allclose
|
||||
if os.path.exists(onnx_file):
|
||||
os.remove(onnx_file)
|
||||
if os.path.exists(trt_file):
|
||||
os.remove(trt_file)
|
||||
assert torch.allclose(pytorch_results, trt_results)
|
||||
|
|
Loading…
Reference in New Issue