diff --git a/backend_ops/tensorrt/CMakeLists.txt b/backend_ops/tensorrt/CMakeLists.txt
index 5efee2d4a..395a1595f 100644
--- a/backend_ops/tensorrt/CMakeLists.txt
+++ b/backend_ops/tensorrt/CMakeLists.txt
@@ -53,7 +53,8 @@ set(PLUGIN_LISTS
     scatternd
     roi_align
     batched_nms
     instance_norm
-    multi_level_roi_align)
+    multi_level_roi_align
+    grid_sampler)
 foreach(PLUGIN_ITER ${PLUGIN_LISTS})
   file(GLOB PLUGIN_OPS_SRCS ${PLUGIN_ITER}/*.cpp ${PLUGIN_ITER}/*.cu)
diff --git a/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.cpp b/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.cpp
new file mode 100644
index 000000000..a97ef8031
--- /dev/null
+++ b/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.cpp
@@ -0,0 +1,222 @@
+#include "trt_grid_sampler.hpp"
+
+#include <assert.h>
+
+#include <chrono>
+
+#include "trt_grid_sampler_kernel.hpp"
+#include "trt_plugin_helper.hpp"
+#include "trt_serialize.hpp"
+
+namespace mmlab {
+namespace {
+static const char *PLUGIN_VERSION{"1"};
+static const char *PLUGIN_NAME{"grid_sampler"};
+}  // namespace
+
+TRTGridSampler::TRTGridSampler(const std::string &name, int mode,
+                               int paddingMode, bool alignCorners)
+    : TRTPluginBase(name),
+      mMode(mode),
+      mPaddingMode(paddingMode),
+      mAlignCorners(alignCorners) {}
+
+TRTGridSampler::TRTGridSampler(const std::string name, const void *data,
+                               size_t length)
+    : TRTPluginBase(name) {
+  deserialize_value(&data, &length, &mMode);
+  deserialize_value(&data, &length, &mPaddingMode);
+  deserialize_value(&data, &length, &mAlignCorners);
+}
+
+nvinfer1::IPluginV2DynamicExt *TRTGridSampler::clone() const TRT_NOEXCEPT {
+  TRTGridSampler *plugin =
+      new TRTGridSampler(mLayerName, mMode, mPaddingMode, mAlignCorners);
+  plugin->setPluginNamespace(getPluginNamespace());
+
+  return plugin;
+}
+
+nvinfer1::DimsExprs TRTGridSampler::getOutputDimensions(
+    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
+  nvinfer1::DimsExprs ret;
+  ret.nbDims = inputs[0].nbDims;
+  ret.d[0] = inputs[0].d[0];
+  ret.d[1] = inputs[0].d[1];
+  for (int i = 2; i < ret.nbDims; ++i) {
+    ret.d[i] = inputs[1].d[i - 1];
+  }
+  return ret;
+}
+
+bool TRTGridSampler::supportsFormatCombination(
+    int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
+    int nbOutputs) TRT_NOEXCEPT {
+  if (pos == 0) {
+    return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
+            inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);
+  } else {
+    return inOut[pos].type == inOut[0].type &&
+           inOut[pos].format == inOut[0].format;
+  }
+}
+
+void TRTGridSampler::configurePlugin(
+    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
+    const nvinfer1::DynamicPluginTensorDesc *outputs,
+    int nbOutputs) TRT_NOEXCEPT {
+  // Validate input arguments
+}
+
+size_t TRTGridSampler::getWorkspaceSize(
+    const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
+    const nvinfer1::PluginTensorDesc *outputs,
+    int nbOutputs) const TRT_NOEXCEPT {
+  return 0;
+}
+
+int TRTGridSampler::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
+                            const nvinfer1::PluginTensorDesc *outputDesc,
+                            const void *const *inputs, void *const *outputs,
+                            void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
+  nvinfer1::Dims input_dims = inputDesc[0].dims;
+  nvinfer1::Dims grid_dims = inputDesc[1].dims;
+  nvinfer1::Dims output_dims = outputDesc[0].dims;
+
+  GridSamplerInterpolation interp_mode = GridSamplerInterpolation::Bilinear;
+  switch (mMode) {
+    case 0:
+      interp_mode = GridSamplerInterpolation::Bilinear;
+      break;
+    case 1:
+      interp_mode = GridSamplerInterpolation::Nearest;
+      break;
+    default:
+      break;
+  }
+
+  GridSamplerPadding padding_mode = GridSamplerPadding::Zeros;
+  switch (mPaddingMode) {
+    case 0:
+      padding_mode = GridSamplerPadding::Zeros;
+      break;
+
+    case 1:
+      padding_mode = GridSamplerPadding::Border;
+      break;
+
+    case 2:
+      padding_mode = GridSamplerPadding::Reflection;
+      break;
+    default:
+      break;
+  }
+
+  auto data_type = inputDesc[0].type;
+
+  switch (data_type) {
+    case nvinfer1::DataType::kFLOAT:
+      grid_sample<float>(
+          (float *)outputs[0], (float *)inputs[0], (float *)inputs[1],
+          &(output_dims.d[0]), &(input_dims.d[0]), &(grid_dims.d[0]),
+          input_dims.nbDims, interp_mode, padding_mode, mAlignCorners, stream);
+      break;
+    default:
+      return 1;
+      break;
+  }
+
+  return 0;
+}
+
+nvinfer1::DataType TRTGridSampler::getOutputDataType(
+    int index, const nvinfer1::DataType *inputTypes,
+    int nbInputs) const TRT_NOEXCEPT {
+  return inputTypes[0];
+}
+
+// IPluginV2 Methods
+const char *TRTGridSampler::getPluginType() const TRT_NOEXCEPT {
+  return PLUGIN_NAME;
+}
+
+const char *TRTGridSampler::getPluginVersion() const TRT_NOEXCEPT {
+  return PLUGIN_VERSION;
+}
+
+int TRTGridSampler::getNbOutputs() const TRT_NOEXCEPT { return 1; }
+
+size_t TRTGridSampler::getSerializationSize() const TRT_NOEXCEPT {
+  return serialized_size(mMode) + serialized_size(mPaddingMode) +
+         serialized_size(mAlignCorners);
+}
+
+void TRTGridSampler::serialize(void *buffer) const TRT_NOEXCEPT {
+  serialize_value(&buffer, mMode);
+  serialize_value(&buffer, mPaddingMode);
+  serialize_value(&buffer, mAlignCorners);
+}
+
+////////////////////// creator /////////////////////////////
+
+TRTGridSamplerCreator::TRTGridSamplerCreator() {
+  mPluginAttributes = std::vector<nvinfer1::PluginField>(
+      {nvinfer1::PluginField("interpolation_mode"),
+       nvinfer1::PluginField("padding_mode"),
+       nvinfer1::PluginField("align_corners")});
+  mFC.nbFields = mPluginAttributes.size();
+  mFC.fields = mPluginAttributes.data();
+}
+
+const char *TRTGridSamplerCreator::getPluginName() const TRT_NOEXCEPT {
+  return PLUGIN_NAME;
+}
+
+const char *TRTGridSamplerCreator::getPluginVersion() const TRT_NOEXCEPT {
+  return PLUGIN_VERSION;
+}
+
+nvinfer1::IPluginV2 *TRTGridSamplerCreator::createPlugin(
+    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
+  int mode = 0;
+  int paddingMode = 0;
+  bool alignCorners = false;
+
+  for (int i = 0; i < fc->nbFields; i++) {
+    if (fc->fields[i].data == nullptr) {
+      continue;
+    }
+    std::string field_name(fc->fields[i].name);
+
+    if (field_name.compare("interpolation_mode") == 0) {
+      mode = static_cast<const int *>(fc->fields[i].data)[0];
+    }
+
+    if (field_name.compare("padding_mode") == 0) {
+      paddingMode = static_cast<const int *>(fc->fields[i].data)[0];
+    }
+
+    if (field_name.compare("align_corners") == 0) {
+      alignCorners = (bool)(static_cast<const int *>(fc->fields[i].data)[0]);
+    }
+  }
+
+  TRTGridSampler *plugin =
+      new TRTGridSampler(name, mode, paddingMode, alignCorners);
+  plugin->setPluginNamespace(getPluginNamespace());
+  return plugin;
+}
+
+nvinfer1::IPluginV2 *TRTGridSamplerCreator::deserializePlugin(
+    const char *name, const void *serialData,
+    size_t serialLength) TRT_NOEXCEPT {
+  // This object will be deleted when the network is destroyed, which will
+  // call TRTGridSampler::destroy()
+  auto plugin = new TRTGridSampler(name, serialData, serialLength);
+  plugin->setPluginNamespace(getPluginNamespace());
+  return plugin;
+}
+
+REGISTER_TENSORRT_PLUGIN(TRTGridSamplerCreator);
+}  // namespace mmlab
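Note: the integer attributes parsed in createPlugin() mirror the switch statements in enqueue(): interpolation_mode 0/1 select bilinear/nearest, padding_mode 0/1/2 select zeros/border/reflection, the same encodings torch.nn.functional.grid_sample exposes as strings. A minimal PyTorch sketch of the expected behavior (tensor sizes are illustrative only, not part of the patch):

    import torch
    import torch.nn.functional as F

    interp = {0: 'bilinear', 1: 'nearest'}                 # mMode encoding
    padding = {0: 'zeros', 1: 'border', 2: 'reflection'}   # mPaddingMode encoding

    inp = torch.randn(1, 3, 8, 8)              # (N, C, H_in, W_in)
    grid = torch.rand(1, 4, 4, 2) * 2 - 1      # (N, H_out, W_out, 2), coords in [-1, 1]
    out = F.grid_sample(inp, grid, mode=interp[0], padding_mode=padding[0],
                        align_corners=False)
    print(out.shape)  # torch.Size([1, 3, 4, 4]): N, C from the input, H, W from
                      # the grid, exactly what getOutputDimensions() computes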
diff --git a/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.hpp b/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.hpp
new file mode 100644
index 000000000..7f8af3d81
--- /dev/null
+++ b/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.hpp
@@ -0,0 +1,92 @@
+#ifndef TRT_GRID_SAMPLER_HPP
+#define TRT_GRID_SAMPLER_HPP
+#include <cublas_v2.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "trt_plugin_base.hpp"
+
+namespace mmlab {
+
+class TRTGridSampler : public TRTPluginBase {
+ public:
+  TRTGridSampler(const std::string &name, int mode, int paddingMode,
+                 bool alignCorners);
+
+  TRTGridSampler(const std::string name, const void *data, size_t length);
+
+  TRTGridSampler() = delete;
+
+  ~TRTGridSampler() TRT_NOEXCEPT override = default;
+
+  // IPluginV2DynamicExt Methods
+  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
+
+  nvinfer1::DimsExprs getOutputDimensions(
+      int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+      nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT override;
+
+  bool supportsFormatCombination(int pos,
+                                 const nvinfer1::PluginTensorDesc *inOut,
+                                 int nbInputs,
+                                 int nbOutputs) TRT_NOEXCEPT override;
+
+  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
+                       int nbInputs,
+                       const nvinfer1::DynamicPluginTensorDesc *out,
+                       int nbOutputs) TRT_NOEXCEPT override;
+
+  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
+                          int nbInputs,
+                          const nvinfer1::PluginTensorDesc *outputs,
+                          int nbOutputs) const TRT_NOEXCEPT override;
+
+  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
+              const nvinfer1::PluginTensorDesc *outputDesc,
+              const void *const *inputs, void *const *outputs, void *workspace,
+              cudaStream_t stream) TRT_NOEXCEPT override;
+
+  // IPluginV2Ext Methods
+  nvinfer1::DataType getOutputDataType(
+      int index, const nvinfer1::DataType *inputTypes,
+      int nbInputs) const TRT_NOEXCEPT override;
+
+  // IPluginV2 Methods
+  const char *getPluginType() const TRT_NOEXCEPT override;
+
+  const char *getPluginVersion() const TRT_NOEXCEPT override;
+
+  int getNbOutputs() const TRT_NOEXCEPT override;
+
+  size_t getSerializationSize() const TRT_NOEXCEPT override;
+
+  void serialize(void *buffer) const TRT_NOEXCEPT override;
+
+ private:
+  int mMode;
+  int mPaddingMode;
+  bool mAlignCorners;
+};
+
+class TRTGridSamplerCreator : public TRTPluginCreatorBase {
+ public:
+  TRTGridSamplerCreator();
+
+  ~TRTGridSamplerCreator() TRT_NOEXCEPT override = default;
+
+  const char *getPluginName() const TRT_NOEXCEPT override;
+
+  const char *getPluginVersion() const TRT_NOEXCEPT override;
+
+  nvinfer1::IPluginV2 *createPlugin(const char *name,
+                                    const nvinfer1::PluginFieldCollection *fc)
+      TRT_NOEXCEPT override;
+
+  nvinfer1::IPluginV2 *deserializePlugin(
+      const char *name, const void *serialData,
+      size_t serialLength) TRT_NOEXCEPT override;
+};
+}  // namespace mmlab
+#endif  // TRT_GRID_SAMPLER_HPP
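Because the creator is registered through REGISTER_TENSORRT_PLUGIN at library-load time, a quick way to confirm the op is visible to TensorRT is to load the compiled ops library and query the plugin registry. A sketch; the library path is hypothetical and depends on where the backend ops are built:

    import ctypes
    import tensorrt as trt

    # Hypothetical path: point this at the built backend ops library.
    ctypes.CDLL('./build/lib/libmmlab_tensorrt_ops.so')

    registry = trt.get_plugin_registry()
    # PLUGIN_NAME and PLUGIN_VERSION from trt_grid_sampler.cpp above.
    creator = registry.get_plugin_creator('grid_sampler', '1', '')
    assert creator is not None, 'grid_sampler plugin was not registered'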
"trt_plugin_helper.hpp" + +using mmlab::TensorDesc; + +// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value, +// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5). +// if align_corners: -1 and +1 get sent to the centers of the corner pixels +// -1 --> 0 +// +1 --> (size - 1) +// scale_factor = (size - 1) / 2 +// if not align_corners: -1 and +1 get sent to the image edges +// -1 --> -0.5 +// +1 --> (size - 1) + 0.5 == size - 0.5 +// scale_factor = size / 2 +template +static __forceinline__ __device__ scalar_t +grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) { + if (align_corners) { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1.f) / 2) * (size - 1); + } else { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1.f) * size - 1) / 2; + } +} + +// Clips coordinates to between 0 and clip_limit - 1 +template +static __forceinline__ __device__ scalar_t clip_coordinates(scalar_t in, + int clip_limit) { + return ::min(static_cast(clip_limit - 1), + ::max(in, static_cast(0))); +} + +// Reflects coordinates until they fall between low and high (inclusive). +// The bounds are passed as twice their value so that half-integer values +// can be represented as ints. +template +static __forceinline__ __device__ scalar_t reflect_coordinates(scalar_t in, + int twice_low, + int twice_high) { + if (twice_low == twice_high) { + return static_cast(0); + } + scalar_t min = static_cast(twice_low) / 2; + scalar_t span = static_cast(twice_high - twice_low) / 2; + in = ::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. + scalar_t extra = ::fmod(in, span); + int flips = static_cast(::floor(in / span)); + if (flips % 2 == 0) { + return extra + min; + } else { + return span - extra + min; + } +} + +template +static __forceinline__ __device__ scalar_t +safe_downgrade_to_int_range(scalar_t x) { + // -100.0 does not have special meaning. This is just to make sure + // it's not within_bounds_2d or within_bounds_3d, and does not cause + // undefined behavior. See #35506. 
+  if (x > INT_MAX - 1 || x < INT_MIN || !::isfinite(static_cast<double>(x)))
+    return static_cast<scalar_t>(-100.0);
+  return x;
+}
+
+// Computes the pixel source index value for a grid coordinate
+template <typename scalar_t>
+static __forceinline__ __device__ scalar_t grid_sampler_compute_source_index(
+    scalar_t coord, int size, GridSamplerPadding padding_mode,
+    bool align_corners) {
+  coord = grid_sampler_unnormalize(coord, size, align_corners);
+  if (padding_mode == GridSamplerPadding::Border) {
+    // clip coordinates to image borders
+    coord = clip_coordinates(coord, size);
+  } else if (padding_mode == GridSamplerPadding::Reflection) {
+    // reflect coordinates by image borders
+    if (align_corners) {
+      coord = reflect_coordinates(coord, 0, 2 * (size - 1));
+    } else {
+      coord = reflect_coordinates(coord, -1, 2 * size - 1);
+    }
+    // clip coordinates to image borders
+    coord = clip_coordinates(coord, size);
+  }
+
+  coord = safe_downgrade_to_int_range(coord);
+  return coord;
+}
+
+static __forceinline__ __device__ bool within_bounds_2d(int h, int w, int H,
+                                                        int W) {
+  return h >= 0 && h < H && w >= 0 && w < W;
+}
+
+static __forceinline__ __device__ bool within_bounds_3d(int d, int h, int w,
+                                                        int D, int H, int W) {
+  return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
+}
+
+template <typename scalar_t>
+__global__ void grid_sampler_2d_kernel(
+    const int nthreads, const scalar_t *input, const scalar_t *grid,
+    scalar_t *output, TensorDesc input_desc, TensorDesc grid_desc,
+    TensorDesc output_desc, const GridSamplerInterpolation interpolation_mode,
+    const GridSamplerPadding padding_mode, bool align_corners) {
+  int C = input_desc.shape[1];
+  int inp_H = input_desc.shape[2];
+  int inp_W = input_desc.shape[3];
+  int out_H = grid_desc.shape[1];
+  int out_W = grid_desc.shape[2];
+  int inp_sN = input_desc.stride[0];
+  int inp_sC = input_desc.stride[1];
+  int inp_sH = input_desc.stride[2];
+  int inp_sW = input_desc.stride[3];
+  int grid_sN = grid_desc.stride[0];
+  int grid_sH = grid_desc.stride[1];
+  int grid_sW = grid_desc.stride[2];
+  int grid_sCoor = grid_desc.stride[3];
+  int out_sN = output_desc.stride[0];
+  int out_sC = output_desc.stride[1];
+  int out_sH = output_desc.stride[2];
+  int out_sW = output_desc.stride[3];
+
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    const int w = index % out_W;
+    const int h = (index / out_W) % out_H;
+    const int n = index / (out_H * out_W);
+    const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
+
+    // get the corresponding input x, y co-ordinates from grid
+    scalar_t ix = grid[grid_offset];
+    scalar_t iy = grid[grid_offset + grid_sCoor];
+
+    ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode,
+                                           align_corners);
+    iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode,
+                                           align_corners);
+
+    if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
+      // get NE, NW, SE, SW pixel values from (x, y)
+      int ix_nw = static_cast<int>(::floor(ix));
+      int iy_nw = static_cast<int>(::floor(iy));
+      int ix_ne = ix_nw + 1;
+      int iy_ne = iy_nw;
+      int ix_sw = ix_nw;
+      int iy_sw = iy_nw + 1;
+      int ix_se = ix_nw + 1;
+      int iy_se = iy_nw + 1;
+
+      // get surfaces to each neighbor:
+      scalar_t nw = (ix_se - ix) * (iy_se - iy);
+      scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
+      scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
+      scalar_t se = (ix - ix_nw) * (iy - iy_nw);
+
+      // calculate bilinear weighted pixel value and set output pixel
+      auto inp_ptr_NC = input + n * inp_sN;
+      auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
+      for (int c = 0; c < C;
+           ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
+        *out_ptr_NCHW = static_cast<scalar_t>(0);
+        if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
+          *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
+        }
+        if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
+          *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
+        }
+        if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
+          *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
+        }
+        if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
+          *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
+        }
+      }
+    } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+      int ix_nearest = static_cast<int>(::round(ix));
+      int iy_nearest = static_cast<int>(::round(iy));
+
+      // assign nearest neighbor pixel value to output pixel
+      auto inp_ptr_NC = input + n * inp_sN;
+      auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
+      for (int c = 0; c < C;
+           ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
+        if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) {
+          *out_ptr_NCHW =
+              inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
+        } else {
+          *out_ptr_NCHW = static_cast<scalar_t>(0);
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+__global__ void grid_sampler_3d_kernel(
+    const int nthreads, const scalar_t *input, const scalar_t *grid,
+    scalar_t *output, TensorDesc input_desc, TensorDesc grid_desc,
+    TensorDesc output_desc, const GridSamplerInterpolation interpolation_mode,
+    const GridSamplerPadding padding_mode, bool align_corners) {
+  int C = input_desc.shape[1];
+  int inp_D = input_desc.shape[2];
+  int inp_H = input_desc.shape[3];
+  int inp_W = input_desc.shape[4];
+  int out_D = grid_desc.shape[1];
+  int out_H = grid_desc.shape[2];
+  int out_W = grid_desc.shape[3];
+  int inp_sN = input_desc.stride[0];
+  int inp_sC = input_desc.stride[1];
+  int inp_sD = input_desc.stride[2];
+  int inp_sH = input_desc.stride[3];
+  int inp_sW = input_desc.stride[4];
+  int grid_sN = grid_desc.stride[0];
+  int grid_sD = grid_desc.stride[1];
+  int grid_sH = grid_desc.stride[2];
+  int grid_sW = grid_desc.stride[3];
+  int grid_sCoor = grid_desc.stride[4];
+  int out_sN = output_desc.stride[0];
+  int out_sC = output_desc.stride[1];
+  int out_sD = output_desc.stride[2];
+  int out_sH = output_desc.stride[3];
+  int out_sW = output_desc.stride[4];
+
+  CUDA_1D_KERNEL_LOOP(index, nthreads) {
+    const int w = index % out_W;
+    const int h = (index / out_W) % out_H;
+    const int d = (index / (out_H * out_W)) % out_D;
+    const int n = index / (out_D * out_H * out_W);
+    const int grid_offset =
+        n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
+
+    // get the corresponding input x, y, z co-ordinates from grid
+    scalar_t ix = grid[grid_offset];
+    scalar_t iy = grid[grid_offset + grid_sCoor];
+    scalar_t iz = grid[grid_offset + 2 * grid_sCoor];
+
+    ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode,
+                                           align_corners);
+    iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode,
+                                           align_corners);
+    iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode,
+                                           align_corners);
+
+    if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
+      // get corner pixel values from (x, y, z)
+      // for 4d, we used north-east-south-west
+      // for 5d, we add top-bottom
+      int ix_tnw = static_cast<int>(::floor(ix));
+      int iy_tnw = static_cast<int>(::floor(iy));
+      int iz_tnw = static_cast<int>(::floor(iz));
+
+      int ix_tne = ix_tnw + 1;
+      int iy_tne = iy_tnw;
+      int iz_tne = iz_tnw;
+
+      int ix_tsw = ix_tnw;
+      int iy_tsw = iy_tnw + 1;
+      int iz_tsw = iz_tnw;
+
+      int ix_tse = ix_tnw + 1;
+      int iy_tse = iy_tnw + 1;
+      int iz_tse = iz_tnw;
+
+      int ix_bnw = ix_tnw;
+      int iy_bnw = iy_tnw;
+      int iz_bnw = iz_tnw + 1;
+
+      int ix_bne = ix_tnw + 1;
+      int iy_bne = iy_tnw;
+      int iz_bne = iz_tnw + 1;
+
+      int ix_bsw = ix_tnw;
+      int iy_bsw = iy_tnw + 1;
+      int iz_bsw = iz_tnw + 1;
+
+      int ix_bse = ix_tnw + 1;
+      int iy_bse = iy_tnw + 1;
+      int iz_bse = iz_tnw + 1;
+
+      // get surfaces to each neighbor:
+      scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
+      scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
+      scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
+      scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
+      scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
+      scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
+      scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
+      scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
+
+      auto inp_ptr_NC = input + n * inp_sN;
+      auto out_ptr_NCDHW =
+          output + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
+      for (int c = 0; c < C;
+           ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
+        // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) *
+        // tne
+        // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) *
+        // tse
+        // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) *
+        // bne
+        // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) *
+        // bse
+        *out_ptr_NCDHW = static_cast<scalar_t>(0);
+        if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
+          *out_ptr_NCDHW +=
+              inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] *
+              tnw;
+        }
+        if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
+          *out_ptr_NCDHW +=
+              inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] *
+              tne;
+        }
+        if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
+          *out_ptr_NCDHW +=
+              inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] *
+              tsw;
+        }
+        if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
+          *out_ptr_NCDHW +=
+              inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] *
+              tse;
+        }
+        if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
+          *out_ptr_NCDHW +=
+              inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] *
+              bnw;
+        }
+        if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
+          *out_ptr_NCDHW +=
+              inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] *
+              bne;
+        }
+        if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
+          *out_ptr_NCDHW +=
+              inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] *
+              bsw;
+        }
+        if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
+          *out_ptr_NCDHW +=
+              inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] *
+              bse;
+        }
+      }
+    } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
+      int ix_nearest = static_cast<int>(::round(ix));
+      int iy_nearest = static_cast<int>(::round(iy));
+      int iz_nearest = static_cast<int>(::round(iz));
+
+      // assign nearest neighbor pixel value to output pixel
+      auto inp_ptr_NC = input + n * inp_sN;
+      auto out_ptr_NCDHW =
+          output + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
+      for (int c = 0; c < C;
+           ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
+        if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H,
+                             inp_W)) {
+          *out_ptr_NCDHW =
+              inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH +
+                         ix_nearest * inp_sW];
+        } else {
+          *out_ptr_NCDHW = static_cast<scalar_t>(0);
+        }
+      }
+    }
+  }
+}
+
+void create_desc(const int *dims, int nb_dims, TensorDesc &desc) {
+  memcpy(&desc.shape[0], dims, sizeof(int) * nb_dims);
+  desc.stride[nb_dims - 1] = 1;
+  for (int i = nb_dims - 2; i >= 0; --i) {
+    desc.stride[i] = desc.stride[i + 1] * desc.shape[i + 1];
+  }
+}
+
+template <typename T>
+void grid_sample(T *output, const T *input, const T *grid, int *output_dims,
+                 int *input_dims, int *grid_dims, int nb_dims,
+                 GridSamplerInterpolation interp, GridSamplerPadding padding,
+                 bool align_corners, cudaStream_t stream) {
+  TensorDesc input_desc;
+  create_desc(input_dims, nb_dims, input_desc);
+
+  TensorDesc output_desc;
+  create_desc(output_dims, nb_dims, output_desc);
+
+  TensorDesc grid_desc;
+  create_desc(grid_dims, nb_dims, grid_desc);
+
+  int count = 1;
+  for (int i = 0; i < nb_dims; ++i) {
+    if (i == 1) {
+      continue;
+    }
+    count *= output_desc.shape[i];
+  }
+
+  if (nb_dims == 4) {
+    grid_sampler_2d_kernel<T>
+        <<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>(
+            count, input, grid, output, input_desc, grid_desc, output_desc,
+            interp, padding, align_corners);
+  } else if (nb_dims == 5) {
+    grid_sampler_3d_kernel<T>
+        <<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>(
+            count, input, grid, output, input_desc, grid_desc, output_desc,
+            interp, padding, align_corners);
+  } else {
+    printf("input and grid dims should be 4 or 5\n");
+  }
+}
+
+template void grid_sample<float>(float *output, const float *input,
+                                 const float *grid, int *output_dims,
+                                 int *input_dims, int *grid_dims, int nb_dims,
+                                 GridSamplerInterpolation interp,
+                                 GridSamplerPadding padding,
+                                 bool align_corners, cudaStream_t stream);
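The coordinate comments at the top of the kernel file can be sanity-checked in a few lines of Python. This mirrors only grid_sampler_unnormalize; a sketch for verification, not part of the patch:

    def unnormalize(coord, size, align_corners):
        # mirrors grid_sampler_unnormalize in the kernel above
        if align_corners:
            return (coord + 1) / 2 * (size - 1)   # -1 -> 0, +1 -> size - 1
        return ((coord + 1) * size - 1) / 2       # -1 -> -0.5, +1 -> size - 0.5

    assert unnormalize(-1.0, 5, True) == 0.0
    assert unnormalize(+1.0, 5, True) == 4.0
    assert unnormalize(-1.0, 5, False) == -0.5
    assert unnormalize(+1.0, 5, False) == 4.5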
diff --git a/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.hpp b/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.hpp
new file mode 100644
index 000000000..bb1902460
--- /dev/null
+++ b/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.hpp
@@ -0,0 +1,13 @@
+#ifndef TRT_GRID_SAMPLER_KERNEL_HPP
+#define TRT_GRID_SAMPLER_KERNEL_HPP
+#include <cuda_runtime.h>
+
+enum class GridSamplerInterpolation { Bilinear, Nearest };
+enum class GridSamplerPadding { Zeros, Border, Reflection };
+
+template <typename T>
+void grid_sample(T *output, const T *input, const T *grid, int *output_dims,
+                 int *input_dims, int *grid_dims, int nb_dims,
+                 GridSamplerInterpolation interp, GridSamplerPadding padding,
+                 bool align_corners, cudaStream_t stream);
+#endif  // TRT_GRID_SAMPLER_KERNEL_HPP
diff --git a/configs/mmdet/base.py b/configs/mmdet/base.py
index 4ed18329d..ca2e76fd7 100644
--- a/configs/mmdet/base.py
+++ b/configs/mmdet/base.py
@@ -3,12 +3,23 @@ codebase = 'mmdet'
 pytorch2onnx = dict(
     input_names=['input'],
     output_names=['dets', 'labels'],
-    dynamic_axes={'input': {
-        0: 'batch',
-        2: 'height',
-        3: 'width'
-    }},
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+            2: 'height',
+            3: 'width'
+        },
+        'dets': {
+            0: 'batch',
+            1: 'num_dets',
+        },
+        'labels': {
+            0: 'batch',
+            1: 'num_dets',
+        },
+    },
 )
+
 post_processing = dict(
     score_threshold=0.05,
     iou_threshold=0.5,
diff --git a/configs/mmdet/mask_base.py b/configs/mmdet/mask_base.py
new file mode 100644
index 000000000..e6766ed3c
--- /dev/null
+++ b/configs/mmdet/mask_base.py
@@ -0,0 +1,25 @@
+_base_ = ['./base.py']
+pytorch2onnx = dict(
+    output_names=['dets', 'labels', 'masks'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+            2: 'height',
+            3: 'width'
+        },
+        'dets': {
+            0: 'batch',
+            1: 'num_dets',
+        },
+        'labels': {
+            0: 'batch',
+            1: 'num_dets',
+        },
+        'masks': {
+            0: 'batch',
+            1: 'num_dets',
+            2: 'height',
+            3: 'width'
+        },
+    },
+)
diff --git a/configs/mmdet/mask_onnxruntime.py b/configs/mmdet/mask_onnxruntime.py
new file mode 100644
index 000000000..262046b96
--- /dev/null
+++ b/configs/mmdet/mask_onnxruntime.py
@@ -0,0 +1 @@
+_base_ = ['./mask_base.py', '../_base_/backends/onnxruntime.py']
diff --git a/configs/mmdet/mask_tensorrt.py b/configs/mmdet/mask_tensorrt.py
new file mode 100644
index 000000000..e2a44fe60
--- /dev/null
+++ b/configs/mmdet/mask_tensorrt.py
@@ -0,0 +1 @@
+_base_ = ['./mask_base.py', './tensorrt_base.py']
diff --git a/configs/mmdet/tensorrt.py b/configs/mmdet/tensorrt.py
index 60e3c2ea9..58a314153 100644
--- a/configs/mmdet/tensorrt.py
+++ b/configs/mmdet/tensorrt.py
@@ -1,7 +1 @@
-_base_ = ['./base.py', '../_base_/backends/tensorrt.py']
-tensorrt_params = dict(model_params=[
-    dict(
-        opt_shape_dict=dict(
-            input=[[1, 3, 320, 320], [1, 3, 800, 1344], [1, 3, 1344, 1344]]),
-        max_workspace_size=1 << 30)
-])
+_base_ = ['./base.py', './tensorrt_base.py']
diff --git a/configs/mmdet/tensorrt_base.py b/configs/mmdet/tensorrt_base.py
new file mode 100644
index 000000000..968b7741b
--- /dev/null
+++ b/configs/mmdet/tensorrt_base.py
@@ -0,0 +1,7 @@
+_base_ = ['../_base_/backends/tensorrt.py']
+tensorrt_params = dict(model_params=[
+    dict(
+        opt_shape_dict=dict(
+            input=[[1, 3, 320, 320], [1, 3, 800, 1344], [1, 3, 1344, 1344]]),
+        max_workspace_size=1 << 30)
+])
diff --git a/mmdeploy/mmdet/export/model_wrappers.py b/mmdeploy/mmdet/export/model_wrappers.py
index 5a7c6a54e..943ccb663 100644
--- a/mmdeploy/mmdet/export/model_wrappers.py
+++ b/mmdeploy/mmdet/export/model_wrappers.py
@@ -153,14 +153,10 @@ class TensorRTDetector(DeployBaseDetector):
         except (ImportError, ModuleNotFoundError):
             warnings.warn('If input model has custom plugins, \
                 you may have to build backend ops with TensorRT')
-
-        output_names = ['dets', 'labels']
-        model = TRTWrapper(engine_file)
-        if model.output_names == 3:
-            output_names.append('masks')
-        self.model = model
-        self.output_names = output_names
-        self.with_mask_output = len(output_names) == 3
+        self.model = TRTWrapper(engine_file)
+        self.output_names = ['dets', 'labels']
+        if len(self.model.output_names) == 3:
+            self.output_names.append('masks')
 
     def forward_test(self, imgs, *args, **kwargs):
         input_data = imgs[0].contiguous()
@@ -168,4 +164,13 @@ class TensorRTDetector(DeployBaseDetector):
         outputs = self.model({'input': input_data})
         outputs = [outputs[name] for name in self.output_names]
         outputs = [out.detach().cpu().numpy() for out in outputs]
+        # filter out invalid outputs padded with -1
+        batch_labels = outputs[1]
+        batch_size = batch_labels.shape[0]
+        inds = batch_labels.reshape(-1) != -1
+        for i in range(len(outputs)):
+            ori_shape = outputs[i].shape
+            outputs[i] = outputs[i].reshape(-1,
+                                            *ori_shape[2:])[inds, ...].reshape(
+                                                batch_size, -1, *ori_shape[2:])
         return outputs
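The padding filter added to forward_test flattens batch and detection axes together, so the final reshape assumes every image in the batch keeps the same number of valid detections after filtering. A toy NumPy illustration with made-up numbers:

    import numpy as np

    labels = np.array([[0, 2, -1], [1, 3, -1]])   # (batch=2, num_dets=3); -1 pads
    dets = np.arange(2 * 3 * 5, dtype=float).reshape(2, 3, 5)
    inds = labels.reshape(-1) != -1               # flatten batch and dets together
    dets = dets.reshape(-1, 5)[inds, ...].reshape(2, -1, 5)
    print(dets.shape)                             # (2, 2, 5): padding rows dropped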
diff --git a/mmdeploy/mmdet/models/detectors/base.py b/mmdeploy/mmdet/models/detectors/base.py
index e480ce2fb..386e45cd1 100644
--- a/mmdeploy/mmdet/models/detectors/base.py
+++ b/mmdeploy/mmdet/models/detectors/base.py
@@ -1,12 +1,11 @@
 import torch
 
-from mmdeploy.core import FUNCTION_REWRITER, mark
+from mmdeploy.core import FUNCTION_REWRITER
 from mmdeploy.utils import is_dynamic_shape
 
 
 @FUNCTION_REWRITER.register_rewriter(
     func_name='mmdet.models.BaseDetector.forward')
-@mark('detector_forward', inputs='input', outputs=['dets', 'labels'])
 def forward_of_base_detector(ctx, self, img, img_metas=None, **kwargs):
     if img_metas is None:
         img_metas = {}
diff --git a/mmdeploy/mmdet/models/roi_heads/__init__.py b/mmdeploy/mmdet/models/roi_heads/__init__.py
index 0b2bd6eb6..6245218db 100644
--- a/mmdeploy/mmdet/models/roi_heads/__init__.py
+++ b/mmdeploy/mmdet/models/roi_heads/__init__.py
@@ -1,4 +1,5 @@
 from .bbox_heads import *  # noqa: F401, F403
+from .mask_heads import *  # noqa: F401, F403
 from .roi_extractors import *  # noqa: F401, F403
 from .standard_roi_head import *  # noqa: F401, F403
 from .test_mixins import *  # noqa: F401, F403
diff --git a/mmdeploy/mmdet/models/roi_heads/mask_heads/__init__.py b/mmdeploy/mmdet/models/roi_heads/mask_heads/__init__.py
new file mode 100644
index 000000000..06a5a3009
--- /dev/null
+++ b/mmdeploy/mmdet/models/roi_heads/mask_heads/__init__.py
@@ -0,0 +1,3 @@
+from .fcn_mask_head import get_seg_masks_of_fcn_mask_head
+
+__all__ = ['get_seg_masks_of_fcn_mask_head']
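The rewritten mask head in the next file pastes instance masks with F.grid_sample instead of mmdet's numpy path. As a quick reference for its normalized-grid convention (x coordinate first, pixel centers at +0.5), here is a standalone sketch with made-up sizes, not part of the patch:

    import torch
    import torch.nn.functional as F

    # One 2x2 mask pasted into a 4x4 canvas over the box (0, 0, 4, 4).
    mask = torch.ones(1, 1, 2, 2)
    xs = (torch.arange(4, dtype=torch.float32) + 0.5) / 4 * 2 - 1  # (img_x - x0) / (x1 - x0) * 2 - 1
    ys = (torch.arange(4, dtype=torch.float32) + 0.5) / 4 * 2 - 1
    gy, gx = torch.meshgrid(ys, xs)             # 'ij' indexing: gy varies along rows
    grid = torch.stack([gx, gy], dim=2)[None]   # (1, 4, 4, 2), x coordinate first
    out = F.grid_sample(mask, grid, align_corners=False)
    print(out.shape)  # torch.Size([1, 1, 4, 4])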
diff --git a/mmdeploy/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdeploy/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
new file mode 100644
index 000000000..97779e88e
--- /dev/null
+++ b/mmdeploy/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
@@ -0,0 +1,99 @@
+import torch
+import torch.nn.functional as F
+
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmdet.models.roi_heads.FCNMaskHead.get_seg_masks')
+def get_seg_masks_of_fcn_mask_head(ctx, self, mask_pred, det_bboxes,
+                                   det_labels, rcnn_test_cfg, ori_shape,
+                                   **kwargs):
+    """Get segmentation masks from mask_pred and bboxes.
+
+    Args:
+        mask_pred (Tensor): shape (n, #class, h, w).
+        det_bboxes (Tensor): shape (n, 4/5)
+        det_labels (Tensor): shape (n, )
+        rcnn_test_cfg (dict): rcnn testing config
+        ori_shape (Tuple): original image height and width, shape (2,)
+
+    Returns:
+        Tensor: a mask of shape (N, img_h, img_w).
+    """
+    backend = ctx.cfg.get('backend', 'default')
+    mask_pred = mask_pred.sigmoid()
+    bboxes = det_bboxes[:, :4]
+    labels = det_labels
+    threshold = rcnn_test_cfg.mask_thr_binary
+    if not self.class_agnostic:
+        box_inds = torch.arange(mask_pred.shape[0], device=bboxes.device)
+        mask_pred = mask_pred[box_inds, labels][:, None]
+    masks, _ = _do_paste_mask(
+        mask_pred, bboxes, ori_shape[0], ori_shape[1], skip_empty=False)
+    if backend == 'tensorrt':
+        return masks
+    if threshold >= 0:
+        masks = (masks >= threshold).to(dtype=torch.bool)
+    return masks
+
+
+def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
+    """Paste instance masks according to boxes.
+
+    This implementation is modified from
+    https://github.com/facebookresearch/detectron2/
+
+    Args:
+        masks (Tensor): N, 1, H, W
+        boxes (Tensor): N, 4
+        img_h (int): Height of the image to be pasted.
+        img_w (int): Width of the image to be pasted.
+        skip_empty (bool): Only paste masks within the region that
+            tightly bounds all boxes, and return the results in this
+            region only. An important optimization for CPU.
+
+    Returns:
+        tuple: (Tensor, tuple). The first item is the mask tensor, the second
+            one is the slice object.
+        If skip_empty == False, the whole image will be pasted. It will
+            return a mask of shape (N, img_h, img_w) and an empty tuple.
+        If skip_empty == True, only the area around the mask will be pasted.
+            A mask of shape (N, h', w') and its start and end coordinates
+            in the original image will be returned.
+    """
+    # On GPU, paste all masks together (up to chunk size)
+    # by using the entire image to sample the masks.
+    # Compared to pasting them one by one,
+    # this has more operations but is faster on COCO-scale datasets.
+    device = masks.device
+    if skip_empty:
+        x0_int, y0_int = torch.clamp(
+            boxes.min(dim=0).values.floor()[:2] - 1,
+            min=0).to(dtype=torch.int32)
+        x1_int = torch.clamp(
+            boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
+        y1_int = torch.clamp(
+            boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
+    else:
+        x0_int, y0_int = 0, 0
+        x1_int, y1_int = img_w, img_h
+    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1
+
+    N = masks.shape[0]
+
+    img_y = torch.arange(y0_int, y1_int, device=device).to(torch.float32) + 0.5
+    img_x = torch.arange(x0_int, x1_int, device=device).to(torch.float32) + 0.5
+    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
+    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
+    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
+    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
+    grid = torch.stack([gx, gy], dim=3)
+
+    img_masks = F.grid_sample(
+        masks.to(dtype=torch.float32), grid, align_corners=False)
+
+    if skip_empty:
+        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
+    else:
+        return img_masks[:, 0], ()
diff --git a/mmdeploy/mmdet/models/roi_heads/test_mixins.py b/mmdeploy/mmdet/models/roi_heads/test_mixins.py
index 4c1731f16..9e78bc433 100644
--- a/mmdeploy/mmdet/models/roi_heads/test_mixins.py
+++ b/mmdeploy/mmdet/models/roi_heads/test_mixins.py
@@ -43,8 +43,6 @@ def simple_test_mask_of_mask_test_mixin(ctx, self, x, img_metas, det_bboxes,
         has not been executed this time'
     batch_size = det_bboxes.size(0)
-    # if det_bboxes is rescaled to the original image size, we need to
-    # rescale it back to the testing scale to obtain RoIs.
     det_bboxes = det_bboxes[..., :4]
     batch_index = torch.arange(
         det_bboxes.size(0),