mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* check in cmake * move backend_ops to csrc/backend_ops * check in preprocess, model, some codebase and their c-apis * check in CMakeLists.txt * check in parts of test_csrc * commit everything else * add readme * update core's BUILD_INTERFACE directory * skip codespell on third_party * update trt_net and ort_net's CMakeLists * ignore clion's build directory * check in pybind11 * add onnx.proto. Remove MMDeploy's dependency on ncnn's source code * export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON * remove useless message * target include directory is wrong * change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net * skip install directory * update project's cmake * remove useless code * set CMAKE_BUILD_TYPE to Release by force if it isn't set by user * update custom ops CMakeLists * pass object target's source lists * fix lint end-of-file * fix lint: trailing whitespace * fix codespell hook * remove bicubic_interpolate to csrc/backend_ops/ * set MMDEPLOY_BUILD_SDK OFF * change custom ops build command * add spdlog installation command * update docs on how to checkout pybind11 * move bicubic_interpolate to backend_ops/tensorrt directory * remove useless code * correct cmake * fix typo * fix typo * fix install directory * correct sdk's readme * set cub dir when cuda version < 11.0 * change directory where clang-format will apply to * fix build command * add .clang-format * change clang-format style from google to file * reformat csrc/backend_ops * format sdk's code * turn off clang-format for some files * add -Xcompiler=-fno-gnu-unique * fix trt topk initialize * check in config for sdk demo * update cmake script and csrc's readme * correct config's path * add cuda include directory, otherwise compile failed in case of tensorrt8.2 * clang-format onnx2ncnn.cpp Co-authored-by: zhangli <lzhang329@gmail.com> Co-authored-by: grimoire <yaoqian@sensetime.com>
50 lines
1.8 KiB
Plaintext
50 lines
1.8 KiB
Plaintext
// Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
#include <stdint.h>
|
|
|
|
namespace mmdeploy {
namespace cuda {

// Copies a dst_h x dst_w window out of an interleaved (channels-last) source image into a
// densely packed interleaved destination buffer.
//
// Launch layout: 2-D grid/blocks; the thread at global (x, y) handles destination pixel
// (x, y) and copies all `channels` values of that pixel. Threads outside
// [0, dst_w) x [0, dst_h) return immediately, so the grid may overshoot the output.
//
// Destination pixel (x, y) is read from source pixel (x + offset_w, y + offset_h).
// Preconditions (NOT checked here -- the kernel has no source height and performs no
// source bounds test):
//   * the whole window [offset_w, offset_w + dst_w) x [offset_h, offset_h + dst_h)
//     lies inside the source image;
//   * source rows are tightly packed, i.e. the row stride is src_w * channels elements
//     (no pitch parameter is accepted).
template <typename T, int channels>
__global__ void crop(const T *src, int src_w, T *dst, int dst_h, int dst_w, int offset_h,
                     int offset_w) {
  const int x = blockIdx.x * blockDim.x + threadIdx.x;
  const int y = blockIdx.y * blockDim.y + threadIdx.y;
  if (x >= dst_w || y >= dst_h) return;

  const int src_x = x + offset_w;
  const int src_y = y + offset_h;

  // Element offsets are formed in 64-bit: for large images (rows * cols * channels
  // > 2^31) the 32-bit product would overflow and address the wrong memory.
  const int64_t dst_loc = (static_cast<int64_t>(y) * dst_w + x) * channels;
  const int64_t src_loc = (static_cast<int64_t>(src_y) * src_w + src_x) * channels;

#pragma unroll
  for (int c = 0; c < channels; ++c) {
    dst[dst_loc + c] = src[src_loc + c];
  }
}

// Host-side launcher for `crop` (see the preconditions documented on the kernel).
//
// Enqueues the kernel on `stream` and returns immediately without synchronizing.
// Launch errors are not checked here; NOTE(review): callers that need them should call
// cudaGetLastError() after this returns.
//
// An empty or negative-sized output (dst_h <= 0 or dst_w <= 0) is a no-op. Without this
// guard the computed grid would have a zero (or, for negative sizes, wrapped huge unsigned)
// dimension, which fails with cudaErrorInvalidConfiguration and leaves a pending error
// that would surface at some unrelated later cudaGetLastError() call.
template <typename T, int channels>
void Crop(const T *src, int src_w, T *dst, int dst_h, int dst_w, int offset_h, int offset_w,
          cudaStream_t stream) {
  if (dst_h <= 0 || dst_w <= 0) return;

  const dim3 thread_block(32, 8);
  const dim3 block_num((dst_w + thread_block.x - 1) / thread_block.x,
                       (dst_h + thread_block.y - 1) / thread_block.y);
  crop<T, channels>
      <<<block_num, thread_block, 0, stream>>>(src, src_w, dst, dst_h, dst_w, offset_h, offset_w);
}

// Explicit instantiations: only these element-type / channel-count combinations are
// compiled in this translation unit.
template void Crop<uint8_t, 3>(const uint8_t *src, int src_w, uint8_t *dst, int dst_h, int dst_w,
                               int offset_h, int offset_w, cudaStream_t stream);

template void Crop<uint8_t, 1>(const uint8_t *src, int src_w, uint8_t *dst, int dst_h, int dst_w,
                               int offset_h, int offset_w, cudaStream_t stream);

template void Crop<float, 3>(const float *src, int src_w, float *dst, int dst_h, int dst_w,
                             int offset_h, int offset_w, cudaStream_t stream);

template void Crop<float, 1>(const float *src, int src_w, float *dst, int dst_h, int dst_w,
                             int offset_h, int offset_w, cudaStream_t stream);

}  // namespace cuda
}  // namespace mmdeploy
|