[Fix] Remove cudnn dependency for transform 'mmaction2::format_shape' (#1509)
* fix format shape
* merge common code
* use throw_exception
* update code format
parent ae785f42e1
commit 52fd4fe9f3
@@ -12,64 +12,9 @@ class FormatShapeImpl : public FormatShapeOp {
  explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {}

 protected:
  Result<void> apply(const std::vector<Tensor>& tensors, Tensor& output, int clip_len,
                     int num_clips) override {
    auto N = static_cast<int64_t>(tensors.size());
    auto H = tensors[0].shape(1);
    auto W = tensors[0].shape(2);
    auto C = tensors[0].shape(3);
  Device host_{0, 0};

    TensorDesc desc = {kHost, DataType::kFLOAT, {N, H, W, C}};
    Tensor imgs(desc);
    auto offset = 0UL;
    auto n_item = H * W * C;
    auto copy_size = n_item * sizeof(float);
    for (int i = 0; i < N; i++) {
      auto src_buffer = tensors[i].buffer();
      auto dst_buffer = imgs.buffer();
      OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
      offset += copy_size;
    }

    OUTCOME_TRY(stream().Wait());

    Tensor dst;
    if (input_format_ == "NCHW") {
      OUTCOME_TRY(dst, FormatNCHW(imgs, clip_len, num_clips));
    }
    if (input_format_ == "NCTHW") {
      OUTCOME_TRY(dst, FormatNCTHW(imgs, clip_len, num_clips));
    }
    TensorShape expand_dim = dst.shape();
    expand_dim.insert(expand_dim.begin(), 1);
    dst.Reshape(expand_dim);
    output = std::move(dst);

    return success();
  }

  Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips) {
    auto N = src.shape(0);
    auto H = src.shape(1);
    auto W = src.shape(2);
    auto C = src.shape(3);
    return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
  };

  Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
    auto N = src.shape(0);
    auto H = src.shape(1);
    auto W = src.shape(2);
    auto C = src.shape(3);
    auto L = clip_len;
    if (N % L != 0) {
      return Status(eInvalidArgument);
    }
    int M = N / L;
    src.Reshape({M, L, H, W, C});

    return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
  };
  const Device& GetDevice() { return host_; }

  Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
                           const std::vector<int>& permutation) {
@@ -113,8 +58,6 @@ class FormatShapeImpl : public FormatShapeOp {
    } while (i >= 0);
    return dst;
  }

  constexpr static Device kHost{0, 0};
};

MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cpu, 0), [](std::string input_format) {

@@ -4,17 +4,15 @@ if (NOT "cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
endif ()

project(mmdeploy_mmaction_cuda_impl CXX)
include(${CMAKE_SOURCE_DIR}/cmake/modules/FindCUDNN.cmake)

add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp)
add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp transpose.cu)
set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1)
if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC))
    target_compile_options(${PROJECT_NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
endif ()
target_include_directories(${PROJECT_NAME} PRIVATE
        ${CUDA_INCLUDE_DIRS})
        ${CUDA_INCLUDE_DIRS})
target_link_libraries(${PROJECT_NAME} PRIVATE
        mmdeploy::core
        cudnn)
        mmdeploy::core)
target_link_libraries(mmdeploy_mmaction PRIVATE ${PROJECT_NAME})
mmdeploy_export(${PROJECT_NAME})

@@ -1,6 +1,6 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "cudnn.h"
#include "cuda_runtime.h"
#include "mmdeploy/codebase/mmaction/format_shape.h"
#include "mmdeploy/core/utils/device_utils.h"

@@ -8,85 +8,16 @@ using namespace std;

namespace mmdeploy::mmaction::cuda {

#define CUDNN_CHECK(condition)                                                 \
  do {                                                                         \
    if (condition != CUDNN_STATUS_SUCCESS) {                                   \
      MMDEPLOY_ERROR("cudnn error, msg = {}", cudnnGetErrorString(condition)); \
    }                                                                          \
  } while (0);
template <typename T>
void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim,
               int total, cudaStream_t stream);

class FormatShapeImpl : public FormatShapeOp {
 public:
  explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {
    CUDNN_CHECK(cudnnCreate(&handle_));
    CUDNN_CHECK(cudnnSetStream(handle_, GetNative<cudaStream_t>(stream())));
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&src_desc_));
    CUDNN_CHECK(cudnnCreateTensorDescriptor(&dst_desc_));
  }

  ~FormatShapeImpl() override {
    CUDNN_CHECK(cudnnDestroy(handle_));
    CUDNN_CHECK(cudnnDestroyTensorDescriptor(src_desc_));
    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dst_desc_));
  }
  explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {}

 protected:
  Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
                     int num_clips) override {
    auto N = static_cast<int64_t>(inputs.size());
    auto H = inputs[0].shape(1);
    auto W = inputs[0].shape(2);
    auto C = inputs[0].shape(3);

    auto t0 = std::chrono::high_resolution_clock::now();
    TensorDesc desc = {device_, DataType::kFLOAT, {N, H, W, C}};
    Tensor imgs(desc);
    int offset = 0;
    int n_item = H * W * C;
    int copy_size = n_item * sizeof(float);
    for (int i = 0; i < N; i++) {
      auto src_buffer = inputs[i].buffer();
      auto dst_buffer = imgs.buffer();
      OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
      offset += copy_size;
    }

    // Tensor dst;
    if (input_format_ == "NCHW") {
      OUTCOME_TRY(output, FormatNCHW(imgs, clip_len, num_clips));
    }
    if (input_format_ == "NCTHW") {
      OUTCOME_TRY(output, FormatNCTHW(imgs, clip_len, num_clips));
    }
    TensorShape expand_dim = output.shape();
    expand_dim.insert(expand_dim.begin(), 1);
    output.Reshape(expand_dim);

    return success();
  }

  Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips) {
    auto N = src.shape(0);
    auto H = src.shape(1);
    auto W = src.shape(2);
    auto C = src.shape(3);
    return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
  };

  Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
    auto N = src.shape(0);
    auto H = src.shape(1);
    auto W = src.shape(2);
    auto C = src.shape(3);
    int L = clip_len;
    if (N % L != 0) {
      return Status(eInvalidArgument);
    }
    int M = N / L;
    src.Reshape({M, L, H, W, C});

    return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
  };
  const Device& GetDevice() { return device(); }

  Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
                           const std::vector<int>& permutation) {
@@ -97,14 +28,6 @@ class FormatShapeImpl : public FormatShapeOp {
    }
    dst.Reshape(shape);

    SetCudnnTensorDescriptor(src_dims, permutation);
    CUDNN_CHECK(cudnnTransformTensor(handle_, &one_, src_desc_, src.data<float>(), &zero_,
                                     dst_desc_, dst.data<float>()));

    return dst;
  }

  void SetCudnnTensorDescriptor(const TensorShape& src_dims, const std::vector<int>& permutation) {
    auto ndim = src_dims.size();
    std::vector<int> dst_dims(ndim);
    for (int i = 0; i < ndim; i++) {
@@ -124,17 +47,16 @@ class FormatShapeImpl : public FormatShapeOp {
      src_strides[i] = buffer[permutation[i]];
    }

    CUDNN_CHECK(cudnnSetTensorNdDescriptor(src_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data(),
                                           src_strides.data()));
    CUDNN_CHECK(cudnnSetTensorNdDescriptor(dst_desc_, CUDNN_DATA_FLOAT, ndim, dst_dims.data(),
                                           dst_strides.data()));
  }
    Buffer _src_strides(Device("cuda"), sizeof(int) * ndim);
    Buffer _dst_strides(Device("cuda"), sizeof(int) * ndim);
    OUTCOME_TRY(stream().Copy(src_strides.data(), _src_strides));
    OUTCOME_TRY(stream().Copy(dst_strides.data(), _dst_strides));

  constexpr static float one_{1.0};
  constexpr static float zero_{0.0};
  cudnnHandle_t handle_{};
  cudnnTensorDescriptor_t src_desc_{};
  cudnnTensorDescriptor_t dst_desc_{};
    ::mmdeploy::mmaction::cuda::Transpose(src.data<float>(), GetNative<int*>(_src_strides),
                                          dst.data<float>(), GetNative<int*>(_dst_strides), ndim,
                                          src.size(), (cudaStream_t)stream().GetNative());
    return dst;
  }
};

MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cuda, 0), [](std::string input_format) {

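For context, the two stride arrays handed to the new Transpose call can be derived from the source dims and the permutation alone: dst_strides are the contiguous strides of the permuted output shape, while src_strides are the source strides gathered by the permutation (the `src_strides[i] = buffer[permutation[i]]` line above). The helper below is a minimal host-side sketch for illustration only; the name MakeStrides is made up and is not part of this commit.

// Illustrative helper (hypothetical, not from the commit): derive the stride
// arrays consumed by mmdeploy::mmaction::cuda::Transpose.
#include <cstdio>
#include <vector>

struct TransposeStrides {
  std::vector<int> src;  // source strides, gathered by the permutation
  std::vector<int> dst;  // contiguous strides of the permuted (output) shape
};

TransposeStrides MakeStrides(const std::vector<int>& src_dims, const std::vector<int>& perm) {
  const int ndim = static_cast<int>(src_dims.size());
  std::vector<int> src_strides(ndim, 1), dst_dims(ndim), gathered(ndim), dst_strides(ndim, 1);
  for (int i = ndim - 2; i >= 0; --i) src_strides[i] = src_strides[i + 1] * src_dims[i + 1];
  for (int i = 0; i < ndim; ++i) {
    dst_dims[i] = src_dims[perm[i]];
    gathered[i] = src_strides[perm[i]];  // matches src_strides[i] = buffer[permutation[i]]
  }
  for (int i = ndim - 2; i >= 0; --i) dst_strides[i] = dst_strides[i + 1] * dst_dims[i + 1];
  return {gathered, dst_strides};
}

int main() {
  // NHWC -> NCHW, e.g. the FormatNCHW path with N=2, H=4, W=5, C=3.
  auto s = MakeStrides({2, 4, 5, 3}, {0, 3, 1, 2});
  for (int v : s.src) printf("%d ", v);  // 60 1 15 3
  printf("/ ");
  for (int v : s.dst) printf("%d ", v);  // 60 20 5 1
  printf("\n");
  return 0;
}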
@@ -0,0 +1,38 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include <stdint.h>
#include <stdio.h>

namespace mmdeploy::mmaction::cuda {

template <typename T>
__global__ void transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides,
                          int ndim, int total) {
  int u = blockIdx.x * blockDim.x + threadIdx.x;
  if (u >= total) {
    return;
  }

  int remaining = u;
  int v = 0;
  for (int i = 0; i < ndim; i++) {
    int p = remaining / dst_strides[i];
    remaining -= p * dst_strides[i];
    v += p * src_strides[i];
  }
  dst[u] = src[v];
}

template <typename T>
void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim,
               int total, cudaStream_t stream) {
  int thread_num = 256;
  int block_num = (total + thread_num - 1) / thread_num;
  transpose<T>
      <<<block_num, thread_num, 0, stream>>>(src, src_strides, dst, dst_strides, ndim, total);
}

template void Transpose<float>(const float* src, const int* src_strides, float* dst,
                               const int* dst_strides, int ndim, int total, cudaStream_t stream);

}  // namespace mmdeploy::mmaction::cuda

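The kernel above decomposes each destination element's linear index with the output strides to recover its multi-index, then re-projects that multi-index through the gathered source strides. A CPU-only reference of the same arithmetic, written as a standalone sketch for illustration (not part of the patch), transposes a 2x3 matrix:

// Host-side reference of the kernel's index mapping (illustration only).
#include <cstdio>
#include <vector>

int main() {
  // Source shape {2, 3}, row-major strides {3, 1}; permutation {1, 0}.
  // dst_strides: contiguous strides of the permuted shape {3, 2} -> {2, 1}.
  // src_strides: source strides gathered by the permutation -> {1, 3}.
  std::vector<int> src = {0, 1, 2, 3, 4, 5};
  std::vector<int> dst(src.size());
  const int src_strides[] = {1, 3};
  const int dst_strides[] = {2, 1};
  const int ndim = 2;
  for (int u = 0; u < static_cast<int>(dst.size()); ++u) {
    int remaining = u, v = 0;
    for (int i = 0; i < ndim; ++i) {
      int p = remaining / dst_strides[i];  // i-th coordinate in the output
      remaining -= p * dst_strides[i];
      v += p * src_strides[i];             // corresponding source offset
    }
    dst[u] = src[v];
  }
  for (int x : dst) printf("%d ", x);  // prints: 0 3 1 4 2 5
  printf("\n");
  return 0;
}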
@@ -12,8 +12,76 @@ namespace mmdeploy::mmaction {
FormatShape::FormatShape(const Value& args) {
  auto input_format = args.value("input_format", std::string(""));
  if (input_format != "NCHW" && input_format != "NCTHW") {
    throw std::domain_error("'input_format' should be 'NCHW' or 'NCTHW'");
    MMDEPLOY_ERROR("'input_format' should be 'NCHW' or 'NCTHW'");
    throw_exception(eInvalidArgument);
  }
  format_ = operation::Managed<mmdeploy::mmaction::FormatShapeOp>::Create(input_format);
}

Result<void> FormatShapeOp::apply(const std::vector<Tensor>& images, Tensor& output, int clip_len,
                                  int num_clips) {
  Tensor inputs;
  OUTCOME_TRY(MergeInputs(images, inputs));
  if (GetDevice().is_host()) {
    OUTCOME_TRY(stream().Wait());
  }

  // Tensor dst;
  if (input_format_ == "NCHW") {
    OUTCOME_TRY(output, FormatNCHW(inputs, clip_len, num_clips));
  }
  if (input_format_ == "NCTHW") {
    OUTCOME_TRY(output, FormatNCTHW(inputs, clip_len, num_clips));
  }

  TensorShape expand_dim = output.shape();
  expand_dim.insert(expand_dim.begin(), 1);
  output.Reshape(expand_dim);

  return success();
}

Result<void> FormatShapeOp::MergeInputs(const std::vector<Tensor>& images, Tensor& inputs) {
  auto N = static_cast<int64_t>(images.size());
  auto H = images[0].shape(1);
  auto W = images[0].shape(2);
  auto C = images[0].shape(3);

  TensorDesc desc = {GetDevice(), DataType::kFLOAT, {N, H, W, C}};
  inputs = Tensor(desc);
  auto offset = 0UL;
  auto n_item = H * W * C;
  auto copy_size = n_item * sizeof(float);
  for (int i = 0; i < N; i++) {
    auto src_buffer = images[i].buffer();
    auto dst_buffer = inputs.buffer();
    OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
    offset += copy_size;
  }
  return success();
}

Result<Tensor> FormatShapeOp::FormatNCHW(Tensor& src, int clip_len, int num_clips) {
  auto N = src.shape(0);
  auto H = src.shape(1);
  auto W = src.shape(2);
  auto C = src.shape(3);
  return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
}

Result<Tensor> FormatShapeOp::FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
  auto N = src.shape(0);
  auto H = src.shape(1);
  auto W = src.shape(2);
  auto C = src.shape(3);
  int L = clip_len;
  if (N % L != 0) {
    return Status(eInvalidArgument);
  }
  int M = N / L;
  src.Reshape({M, L, H, W, C});

  return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
}

Result<void> FormatShape::Apply(Value& data) {
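As a quick sanity check of the shared NCTHW path above (numbers chosen purely for illustration): with clip_len = 8 and num_clips = 2, MergeInputs stacks N = 16 frames of shape HxWxC, FormatNCTHW reshapes them to {M, L, H, W, C} with M = N / clip_len = 2 and permutes to {M, C, L, H, W}, and apply() finally prepends a leading batch dimension of 1.

// Shape walk-through for the NCTHW branch (illustrative sketch only).
#include <cstdio>
#include <vector>

int main() {
  const int clip_len = 8, num_clips = 2, H = 224, W = 224, C = 3;
  const int N = clip_len * num_clips;  // 16 merged frames, shape {N, H, W, C}
  const int M = N / clip_len;          // 2 clips
  // Reshape to {M, L, H, W, C}, Transpose with {0, 4, 1, 2, 3}, then expand_dim:
  const std::vector<int> out = {1, M, C, clip_len, H, W};
  for (int d : out) printf("%d ", d);  // prints: 1 2 3 8 224 224
  printf("\n");
  return 0;
}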
@@ -50,6 +118,7 @@ Result<void> FormatShape::Apply(Value& data) {
  }

  Tensor dst;
  data = Value{};
  OUTCOME_TRY(format_.Apply(images, dst, clip_len, num_clips));
  data["img"] = std::move(dst);

@@ -16,8 +16,19 @@ class FormatShapeOp : public operation::Operation {
 public:
  explicit FormatShapeOp(std::string input_format) : input_format_(std::move(input_format)){};

  virtual Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
                             int num_clips) = 0;
  Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
                     int num_clips);

  virtual const Device& GetDevice() = 0;

  virtual Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
                                   const std::vector<int>& permutation) = 0;

  Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips);

  Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips);

  Result<void> MergeInputs(const std::vector<Tensor>& images, Tensor& inputs);

 protected:
  std::string input_format_;

@@ -9,7 +9,7 @@ namespace mmdeploy::transform {
class Lift : public Transform {
 public:
  explicit Lift(const Value& args) {
    const char* type = "compose";
    const char* type = "Compose";
    if (auto creator = gRegistry<Transform>().Get(type)) {
      compose_ = creator->Create(args);
    } else {