diff --git a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp index 313a37dc0..ba5823f55 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp @@ -119,10 +119,17 @@ int mmdeploy_segmentor_get_result(mmdeploy_value_t output, mmdeploy_segmentation results_ptr->height = segmentor_output.height; results_ptr->width = segmentor_output.width; results_ptr->classes = segmentor_output.classes; - auto mask_size = results_ptr->height * results_ptr->width; auto& mask = segmentor_output.mask; - results_ptr->mask = mask.data(); - buffers[i] = mask.buffer(); + auto& score = segmentor_output.score; + results_ptr->mask = nullptr; + results_ptr->score = nullptr; + if (mask.shape().size()) { + results_ptr->mask = mask.data(); + buffers[i] = mask.buffer(); + } else { + results_ptr->score = score.data(); + buffers[i] = score.buffer(); + } } *results = results_data; diff --git a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h index 7ae77a03f..65bcfd03f 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h @@ -17,11 +17,14 @@ extern "C" { #endif typedef struct mmdeploy_segmentation_t { - int height; ///< height of \p mask that equals to the input image's height - int width; ///< width of \p mask that equals to the input image's width - int classes; ///< the number of labels in \p mask - int* mask; ///< segmentation mask of the input image, in which mask[i * width + j] indicates - ///< the label id of pixel at (i, j) + int height; ///< height of \p mask that equals to the input image's height + int width; ///< width of \p mask that equals to the input image's width + int classes; ///< the number of labels in \p mask + int* mask; ///< segmentation mask of the input image, in which mask[i * width + j] indicates + ///< the label id of pixel at (i, j), this field might be null + float* score; ///< segmentation score map of the input image in CHW format, in which + ///< score[height * width * k + i * width + j] indicates the score + ///< of class k at pixel (i, j), this field might be null } mmdeploy_segmentation_t; typedef struct mmdeploy_segmentor* mmdeploy_segmentor_t; diff --git a/csrc/mmdeploy/apis/csharp/MMDeploy/APIs/Segmentor.cs b/csrc/mmdeploy/apis/csharp/MMDeploy/APIs/Segmentor.cs index 6470a2d8a..c3b75ca60 100644 --- a/csrc/mmdeploy/apis/csharp/MMDeploy/APIs/Segmentor.cs +++ b/csrc/mmdeploy/apis/csharp/MMDeploy/APIs/Segmentor.cs @@ -10,6 +10,7 @@ namespace MMDeploy public int Width; public int Classes; public int* Mask; + public float* Score; } #pragma warning restore 0649 @@ -34,10 +35,16 @@ namespace MMDeploy public int Classes; /// - /// Mask data. + /// Mask data, mask[i * width + j] indicates the label id of pixel at (i, j). /// public int[] Mask; + /// + /// Score data, score[height * width * k + i * width + j] indicates the score + /// of class k at pixel (i, j). + /// + public float[] Score; + /// /// Initializes a new instance of the struct. /// @@ -45,13 +52,31 @@ namespace MMDeploy /// width. /// classes. /// mask. - public SegmentorOutput(int height, int width, int classes, int[] mask) + /// score. + public SegmentorOutput(int height, int width, int classes, int[] mask, float[] score) { Height = height; Width = width; Classes = classes; - Mask = new int[Height * Width]; - Array.Copy(mask, this.Mask, mask.Length); + if (mask.Length > 0) + { + Mask = new int[Height * Width]; + Array.Copy(mask, this.Mask, mask.Length); + } + else + { + Mask = new int[] { }; + } + + if (score.Length > 0) + { + Score = new float[Height * Width * Classes]; + Array.Copy(score, this.Score, score.Length); + } + else + { + Score = new float[] { }; + } } internal unsafe SegmentorOutput(CSegment* result) @@ -59,11 +84,34 @@ namespace MMDeploy Height = result->Height; Width = result->Width; Classes = result->Classes; - Mask = new int[Height * Width]; - int nbytes = Height * Width * sizeof(int); - fixed (int* data = this.Mask) + if (result->Mask != null) { - Buffer.MemoryCopy(result->Mask, data, nbytes, nbytes); + Mask = new int[Height * Width]; + + int nbytes = Height * Width * sizeof(int); + fixed (int* data = this.Mask) + { + Buffer.MemoryCopy(result->Mask, data, nbytes, nbytes); + } + } + else + { + Mask = new int[] { }; + } + + if (result->Score != null) + { + Score = new float[Height * Width * Classes]; + + int nbytes = Height * Width * Classes * sizeof(float); + fixed (float* data = this.Score) + { + Buffer.MemoryCopy(result->Score, data, nbytes, nbytes); + } + } + else + { + Score = new float[] { }; } } } diff --git a/csrc/mmdeploy/apis/python/segmentor.cpp b/csrc/mmdeploy/apis/python/segmentor.cpp index 1fdf719fc..940972ab6 100644 --- a/csrc/mmdeploy/apis/python/segmentor.cpp +++ b/csrc/mmdeploy/apis/python/segmentor.cpp @@ -37,12 +37,22 @@ class PySegmentor { std::vector rets(mats.size()); for (size_t i = 0; i < mats.size(); ++i) { - rets[i] = { - {segm[i].height, segm[i].width}, // shape - segm[i].mask, // data - py::capsule(new Sptr(holder), // handle - [](void* p) { delete reinterpret_cast(p); }) // - }; + if (segm[i].mask != nullptr) { + rets[i] = { + {segm[i].height, segm[i].width}, // shape + segm[i].mask, // mask + py::capsule(new Sptr(holder), // handle + [](void* p) { delete reinterpret_cast(p); }) // + }; + } + if (segm[i].score != nullptr) { + rets[i] = { + {segm[i].classes, segm[i].height, segm[i].width}, // shape + segm[i].score, // score + py::capsule(new Sptr(holder), // handle + [](void* p) { delete reinterpret_cast(p); }) // + }; + } } return rets; } diff --git a/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt b/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt index 7aa25aebe..2ea41f727 100644 --- a/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt +++ b/csrc/mmdeploy/codebase/mmaction/CMakeLists.txt @@ -4,8 +4,7 @@ project(mmdeploy_mmaction) file(GLOB SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp") mmdeploy_add_module(${PROJECT_NAME} "${SRCS}") -add_subdirectory(cpu) -add_subdirectory(cuda) + target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_operation mmdeploy_transform diff --git a/csrc/mmdeploy/codebase/mmaction/cpu/CMakeLists.txt b/csrc/mmdeploy/codebase/mmaction/cpu/CMakeLists.txt deleted file mode 100644 index c81552839..000000000 --- a/csrc/mmdeploy/codebase/mmaction/cpu/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. - -project(mmdeploy_mmaction_cpu_impl CXX) - -if ("cpu" IN_LIST MMDEPLOY_TARGET_DEVICES) - add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp) - set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1) - if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC)) - target_compile_options(${PROJECT_NAME} PRIVATE $<$:-fvisibility=hidden>) - endif () - target_link_libraries(${PROJECT_NAME} PRIVATE - mmdeploy::core) - target_link_libraries(mmdeploy_mmaction PRIVATE ${PROJECT_NAME}) - mmdeploy_export(${PROJECT_NAME}) -endif () diff --git a/csrc/mmdeploy/codebase/mmaction/cpu/format_shape_impl.cpp b/csrc/mmdeploy/codebase/mmaction/cpu/format_shape_impl.cpp deleted file mode 100644 index 0a6900cbf..000000000 --- a/csrc/mmdeploy/codebase/mmaction/cpu/format_shape_impl.cpp +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "mmdeploy/codebase/mmaction/format_shape.h" -#include "mmdeploy/core/utils/device_utils.h" - -using namespace std; - -namespace mmdeploy::mmaction::cpu { - -class FormatShapeImpl : public FormatShapeOp { - public: - explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {} - - protected: - Device host_{0, 0}; - - const Device& GetDevice() { return host_; } - - Result Transpose(Tensor& src, const TensorShape& src_dims, - const std::vector& permutation) { - Tensor dst(src.desc()); - TensorShape shape(src.shape().size()); - for (int i = 0; i < shape.size(); i++) { - shape[i] = src.shape(permutation[i]); - } - dst.Reshape(shape); - int ndim = shape.size(); - std::vector dst_strides(ndim); - std::vector src_strides(ndim); - dst_strides[ndim - 1] = src_strides[ndim - 1] = 1; - for (int i = ndim - 2; i >= 0; i--) { - dst_strides[i] = dst_strides[i + 1] * shape[i + 1]; - src_strides[i] = src_strides[i + 1] * src_dims[i + 1]; - } - std::vector tmp(ndim); - for (int i = 0; i < ndim; i++) { - tmp[i] = src_strides[permutation[i]]; - } - src_strides.swap(tmp); - std::vector coord(ndim, 0); - auto dst_data = dst.data(); - auto src_data = src.data(); - - int i; - do { - dst_data[0] = src_data[0]; - for (i = ndim - 1; i >= 0; i--) { - if (++coord[i] == shape[i]) { - coord[i] = 0; - dst_data -= (shape[i] - 1) * dst_strides[i]; - src_data -= (shape[i] - 1) * src_strides[i]; - } else { - dst_data += dst_strides[i]; - src_data += src_strides[i]; - break; - } - } - } while (i >= 0); - return dst; - } -}; - -MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cpu, 0), [](std::string input_format) { - return std::make_unique(std::move(input_format)); -}); - -} // namespace mmdeploy::mmaction::cpu diff --git a/csrc/mmdeploy/codebase/mmaction/cuda/CMakeLists.txt b/csrc/mmdeploy/codebase/mmaction/cuda/CMakeLists.txt deleted file mode 100644 index 9502a3396..000000000 --- a/csrc/mmdeploy/codebase/mmaction/cuda/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -if (NOT "cuda" IN_LIST MMDEPLOY_TARGET_DEVICES) - return() -endif () - -project(mmdeploy_mmaction_cuda_impl CXX) - -add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp transpose.cu) -set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1) -if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC)) - target_compile_options(${PROJECT_NAME} PRIVATE $<$:-fvisibility=hidden>) -endif () -target_include_directories(${PROJECT_NAME} PRIVATE - ${CUDA_INCLUDE_DIRS}) -target_link_libraries(${PROJECT_NAME} PRIVATE - mmdeploy::core) -target_link_libraries(mmdeploy_mmaction PRIVATE ${PROJECT_NAME}) -mmdeploy_export(${PROJECT_NAME}) diff --git a/csrc/mmdeploy/codebase/mmaction/cuda/format_shape_impl.cpp b/csrc/mmdeploy/codebase/mmaction/cuda/format_shape_impl.cpp deleted file mode 100644 index 391b5afc6..000000000 --- a/csrc/mmdeploy/codebase/mmaction/cuda/format_shape_impl.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include "cuda_runtime.h" -#include "mmdeploy/codebase/mmaction/format_shape.h" -#include "mmdeploy/core/utils/device_utils.h" - -using namespace std; - -namespace mmdeploy::mmaction::cuda { - -template -void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim, - int total, cudaStream_t stream); - -class FormatShapeImpl : public FormatShapeOp { - public: - explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {} - - protected: - const Device& GetDevice() { return device(); } - - Result Transpose(Tensor& src, const TensorShape& src_dims, - const std::vector& permutation) { - Tensor dst(src.desc()); - TensorShape shape(src.shape().size()); - for (int i = 0; i < shape.size(); i++) { - shape[i] = src.shape(permutation[i]); - } - dst.Reshape(shape); - - auto ndim = src_dims.size(); - std::vector dst_dims(ndim); - for (int i = 0; i < ndim; i++) { - dst_dims[i] = src_dims[permutation[i]]; - } - - std::vector src_strides(ndim); - std::vector dst_strides(ndim); - std::vector buffer(ndim); - buffer.back() = 1; - dst_strides.back() = 1; - for (int i = ndim - 1; i > 0; i--) { - buffer[i - 1] = buffer[i] * src_dims[i]; - dst_strides[i - 1] = dst_strides[i] * dst_dims[i]; - } - for (int i = 0; i < ndim; ++i) { - src_strides[i] = buffer[permutation[i]]; - } - - Buffer _src_strides(Device("cuda"), sizeof(int) * ndim); - Buffer _dst_strides(Device("cuda"), sizeof(int) * ndim); - OUTCOME_TRY(stream().Copy(src_strides.data(), _src_strides)); - OUTCOME_TRY(stream().Copy(dst_strides.data(), _dst_strides)); - - ::mmdeploy::mmaction::cuda::Transpose(src.data(), GetNative(_src_strides), - dst.data(), GetNative(_dst_strides), ndim, - src.size(), (cudaStream_t)stream().GetNative()); - return dst; - } -}; - -MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cuda, 0), [](std::string input_format) { - return std::make_unique(std::move(input_format)); -}); - -} // namespace mmdeploy::mmaction::cuda diff --git a/csrc/mmdeploy/codebase/mmaction/cuda/transpose.cu b/csrc/mmdeploy/codebase/mmaction/cuda/transpose.cu deleted file mode 100644 index bef1d447e..000000000 --- a/csrc/mmdeploy/codebase/mmaction/cuda/transpose.cu +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. - -#include -#include - -namespace mmdeploy { -namespace mmaction { -namespace cuda { - -template -__global__ void transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, - int ndim, int total) { - int u = blockIdx.x * blockDim.x + threadIdx.x; - if (u >= total) { - return; - } - - int remaining = u; - int v = 0; - for (int i = 0; i < ndim; i++) { - int p = remaining / dst_strides[i]; - remaining -= p * dst_strides[i]; - v += p * src_strides[i]; - } - dst[u] = src[v]; -} - -template -void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim, - int total, cudaStream_t stream) { - int thread_num = 256; - int block_num = (total + thread_num - 1) / thread_num; - transpose - <<>>(src, src_strides, dst, dst_strides, ndim, total); -} - -template void Transpose(const float* src, const int* src_strides, float* dst, - const int* dst_strides, int ndim, int total, cudaStream_t stream); - -} // namespace cuda -} // namespace mmaction -} // namespace mmdeploy diff --git a/csrc/mmdeploy/codebase/mmaction/format_shape.cpp b/csrc/mmdeploy/codebase/mmaction/format_shape.cpp index 81d9ac478..7d8c6ac5c 100644 --- a/csrc/mmdeploy/codebase/mmaction/format_shape.cpp +++ b/csrc/mmdeploy/codebase/mmaction/format_shape.cpp @@ -10,28 +10,47 @@ using namespace std; namespace mmdeploy::mmaction { FormatShape::FormatShape(const Value& args) { - auto input_format = args.value("input_format", std::string("")); - if (input_format != "NCHW" && input_format != "NCTHW") { + input_format_ = args.value("input_format", std::string("")); + if (input_format_ != "NCHW" && input_format_ != "NCTHW") { MMDEPLOY_ERROR("'input_format' should be 'NCHW' or 'NCTHW'"); throw_exception(eInvalidArgument); } - format_ = operation::Managed::Create(input_format); + permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); } -Result FormatShapeOp::apply(const std::vector& images, Tensor& output, int clip_len, - int num_clips) { +Result FormatShape::MergeInputs(const std::vector& images, Tensor& inputs) { + auto N = static_cast(images.size()); + auto H = images[0].shape(1); + auto W = images[0].shape(2); + auto C = images[0].shape(3); + auto& device = operation::gContext().device(); + auto& stream = operation::gContext().stream(); + + TensorDesc desc = {device, DataType::kFLOAT, {N, H, W, C}}; + inputs = Tensor(desc); + auto offset = 0UL; + auto n_item = H * W * C; + auto copy_size = n_item * sizeof(float); + for (int i = 0; i < N; i++) { + auto src_buffer = images[i].buffer(); + auto dst_buffer = inputs.buffer(); + OUTCOME_TRY(stream.Copy(src_buffer, dst_buffer, copy_size, 0, offset)); + offset += copy_size; + } + return success(); +} + +Result FormatShape::Format(const std::vector& images, Tensor& output, int clip_len, + int num_clips) { Tensor inputs; OUTCOME_TRY(MergeInputs(images, inputs)); - if (GetDevice().is_host()) { - OUTCOME_TRY(stream().Wait()); - } // Tensor dst; if (input_format_ == "NCHW") { - OUTCOME_TRY(output, FormatNCHW(inputs, clip_len, num_clips)); + OUTCOME_TRY(FormatNCHW(inputs, clip_len, num_clips, output)); } if (input_format_ == "NCTHW") { - OUTCOME_TRY(output, FormatNCTHW(inputs, clip_len, num_clips)); + OUTCOME_TRY(FormatNCTHW(inputs, clip_len, num_clips, output)); } TensorShape expand_dim = output.shape(); @@ -41,35 +60,13 @@ Result FormatShapeOp::apply(const std::vector& images, Tensor& out return success(); } -Result FormatShapeOp::MergeInputs(const std::vector& images, Tensor& inputs) { - auto N = static_cast(images.size()); - auto H = images[0].shape(1); - auto W = images[0].shape(2); - auto C = images[0].shape(3); - - TensorDesc desc = {GetDevice(), DataType::kFLOAT, {N, H, W, C}}; - inputs = Tensor(desc); - auto offset = 0UL; - auto n_item = H * W * C; - auto copy_size = n_item * sizeof(float); - for (int i = 0; i < N; i++) { - auto src_buffer = images[i].buffer(); - auto dst_buffer = inputs.buffer(); - OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset)); - offset += copy_size; - } +Result FormatShape::FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) { + const vector axes = {0, 3, 1, 2}; + OUTCOME_TRY(permute_.Apply(src, dst, axes)); return success(); } -Result FormatShapeOp::FormatNCHW(Tensor& src, int clip_len, int num_clips) { - auto N = src.shape(0); - auto H = src.shape(1); - auto W = src.shape(2); - auto C = src.shape(3); - return Transpose(src, {N, H, W, C}, {0, 3, 1, 2}); -} - -Result FormatShapeOp::FormatNCTHW(Tensor& src, int clip_len, int num_clips) { +Result FormatShape::FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) { auto N = src.shape(0); auto H = src.shape(1); auto W = src.shape(2); @@ -80,8 +77,9 @@ Result FormatShapeOp::FormatNCTHW(Tensor& src, int clip_len, int num_cli } int M = N / L; src.Reshape({M, L, H, W, C}); - - return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3}); + const vector axes = {0, 4, 1, 2, 3}; + OUTCOME_TRY(permute_.Apply(src, dst, axes)); + return success(); } Result FormatShape::Apply(Value& data) { @@ -119,7 +117,7 @@ Result FormatShape::Apply(Value& data) { Tensor dst; data = Value{}; - OUTCOME_TRY(format_.Apply(images, dst, clip_len, num_clips)); + OUTCOME_TRY(Format(images, dst, clip_len, num_clips)); data["img"] = std::move(dst); return success(); @@ -127,6 +125,4 @@ Result FormatShape::Apply(Value& data) { MMDEPLOY_REGISTER_TRANSFORM(FormatShape); -MMDEPLOY_DEFINE_REGISTRY(FormatShapeOp); - } // namespace mmdeploy::mmaction diff --git a/csrc/mmdeploy/codebase/mmaction/format_shape.h b/csrc/mmdeploy/codebase/mmaction/format_shape.h index 13b664878..97e4f9935 100644 --- a/csrc/mmdeploy/codebase/mmaction/format_shape.h +++ b/csrc/mmdeploy/codebase/mmaction/format_shape.h @@ -1,50 +1,38 @@ // Copyright (c) OpenMMLab. All rights reserved. -#ifndef MMDEPLOY_SRC_CODEBASE_MMACTION_FORMAT_SHAPE_H_ -#define MMDEPLOY_SRC_CODEBASE_MMACTION_FORMAT_SHAPE_H_ +#ifndef MMDEPLOY_CODEBASE_MMACTION_FORMAT_SHAPE_H_ +#define MMDEPLOY_CODEBASE_MMACTION_FORMAT_SHAPE_H_ #include +#include #include #include "mmdeploy/core/tensor.h" #include "mmdeploy/operation/managed.h" +#include "mmdeploy/operation/vision.h" #include "mmdeploy/preprocess/transform/transform.h" namespace mmdeploy::mmaction { -class FormatShapeOp : public operation::Operation { - public: - explicit FormatShapeOp(std::string input_format) : input_format_(std::move(input_format)){}; - - Result apply(const std::vector& inputs, Tensor& output, int clip_len, - int num_clips); - - virtual const Device& GetDevice() = 0; - - virtual Result Transpose(Tensor& src, const TensorShape& src_dims, - const std::vector& permutation) = 0; - - Result FormatNCHW(Tensor& src, int clip_len, int num_clips); - - Result FormatNCTHW(Tensor& src, int clip_len, int num_clips); - - Result MergeInputs(const std::vector& images, Tensor& inputs); - - protected: - std::string input_format_; -}; - class FormatShape : public Transform { public: explicit FormatShape(const Value& args); Result Apply(Value& data) override; - private: - operation::Managed format_; -}; + Result Format(const std::vector& images, Tensor& output, int clip_len, + int num_clips); -MMDEPLOY_DECLARE_REGISTRY(FormatShapeOp, std::unique_ptr(std::string input_format)); + Result FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst); + + Result FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst); + + Result MergeInputs(const std::vector& images, Tensor& inputs); + + private: + std::string input_format_; + operation::Managed permute_; +}; } // namespace mmdeploy::mmaction diff --git a/csrc/mmdeploy/codebase/mmaction/mmaction.h b/csrc/mmdeploy/codebase/mmaction/mmaction.h index 238b78096..ef097e6f2 100644 --- a/csrc/mmdeploy/codebase/mmaction/mmaction.h +++ b/csrc/mmdeploy/codebase/mmaction/mmaction.h @@ -1,7 +1,7 @@ // Copyright (c) OpenMMLab. All rights reserved. -#ifndef MMDEPLOY_SRC_CODEBASE_MMACTION_MMACTION_H_ -#define MMDEPLOY_SRC_CODEBASE_MMACTION_MMACTION_H_ +#ifndef MMDEPLOY_CODEBASE_MMACTION_MMACTION_H_ +#define MMDEPLOY_CODEBASE_MMACTION_MMACTION_H_ #include "mmdeploy/codebase/common.h" #include "mmdeploy/core/device.h" diff --git a/csrc/mmdeploy/codebase/mmseg/CMakeLists.txt b/csrc/mmdeploy/codebase/mmseg/CMakeLists.txt index fc05b1e20..aac237634 100644 --- a/csrc/mmdeploy/codebase/mmseg/CMakeLists.txt +++ b/csrc/mmdeploy/codebase/mmseg/CMakeLists.txt @@ -4,7 +4,9 @@ project(mmdeploy_mmseg) file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp") mmdeploy_add_module(${PROJECT_NAME} "${SRCS}") -target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_opencv_utils) +target_link_libraries(${PROJECT_NAME} PRIVATE + mmdeploy_opencv_utils + mmdeploy_operation) add_library(mmdeploy::mmseg ALIAS ${PROJECT_NAME}) set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} segmentor CACHE INTERNAL "") diff --git a/csrc/mmdeploy/codebase/mmseg/mmseg.h b/csrc/mmdeploy/codebase/mmseg/mmseg.h index 8a9c0f6d7..8f55fadce 100644 --- a/csrc/mmdeploy/codebase/mmseg/mmseg.h +++ b/csrc/mmdeploy/codebase/mmseg/mmseg.h @@ -12,10 +12,11 @@ namespace mmdeploy::mmseg { struct SegmentorOutput { Tensor mask; + Tensor score; int height; int width; int classes; - MMDEPLOY_ARCHIVE_MEMBERS(mask, height, width, classes); + MMDEPLOY_ARCHIVE_MEMBERS(mask, score, height, width, classes); }; MMDEPLOY_DECLARE_CODEBASE(MMSegmentation, mmseg); diff --git a/csrc/mmdeploy/codebase/mmseg/segment.cpp b/csrc/mmdeploy/codebase/mmseg/segment.cpp index b1128886c..56811a4fa 100644 --- a/csrc/mmdeploy/codebase/mmseg/segment.cpp +++ b/csrc/mmdeploy/codebase/mmseg/segment.cpp @@ -5,6 +5,8 @@ #include "mmdeploy/core/tensor.h" #include "mmdeploy/core/utils/device_utils.h" #include "mmdeploy/core/utils/formatter.h" +#include "mmdeploy/operation/managed.h" +#include "mmdeploy/operation/vision.h" #include "mmdeploy/preprocess/transform/transform.h" #include "opencv_utils.h" @@ -18,7 +20,10 @@ class ResizeMask : public MMSegmentation { explicit ResizeMask(const Value &cfg) : MMSegmentation(cfg) { try { classes_ = cfg["params"]["num_classes"].get(); + with_argmax_ = cfg["params"].value("with_argmax", true); little_endian_ = IsLittleEndian(); + ::mmdeploy::operation::Context ctx(Device("cpu"), stream_); + permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); } catch (const std::exception &e) { MMDEPLOY_ERROR("no ['params']['num_classes'] is specified in cfg: {}", cfg); throw_exception(eInvalidArgument); @@ -31,40 +36,71 @@ class ResizeMask : public MMSegmentation { auto mask = inference_result["output"].get(); MMDEPLOY_DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(), mask.data_type()); - if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) { + if (!(mask.shape().size() == 4 && mask.shape(0) == 1)) { MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}", mask.shape()); return Status(eNotSupported); } + if ((mask.shape(1) != 1) && with_argmax_) { + MMDEPLOY_ERROR("probability feat map with shape: {} requires `with_argmax_=false`", + mask.shape()); + return Status(eNotSupported); + } + if ((mask.data_type() != DataType::kFLOAT) && !with_argmax_) { + MMDEPLOY_ERROR("probability feat map only support float32 output"); + return Status(eNotSupported); + } + auto channel = (int)mask.shape(1); auto height = (int)mask.shape(2); auto width = (int)mask.shape(3); auto input_height = preprocess_result["img_metas"]["ori_shape"][1].get(); auto input_width = preprocess_result["img_metas"]["ori_shape"][2].get(); Device host{"cpu"}; OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(mask, host, stream_)); - OUTCOME_TRY(stream_.Wait()); + if (!with_argmax_) { + // (C, H, W) -> (H, W, C) + ::mmdeploy::operation::Context ctx(host, stream_); + std::vector axes = {0, 2, 3, 1}; + OUTCOME_TRY(permute_.Apply(host_tensor, host_tensor, axes)); + } - OUTCOME_TRY(auto cv_type, GetCvType(mask.data_type())); + OUTCOME_TRY(auto cv_type, GetCvType(mask.data_type(), channel)); cv::Mat mask_mat(height, width, cv_type, host_tensor.data()); - if (mask_mat.channels() > 1) { - cv::extractChannel(mask_mat, mask_mat, little_endian_ ? 0 : mask_mat.channels() - 1); - } - if (mask_mat.type() != CV_32S) { - mask_mat.convertTo(mask_mat, CV_32S); + cv::Mat resized_mask; + cv::Mat resized_score; + + Tensor tensor_mask{}; + Tensor tensor_score{}; + + if (with_argmax_) { + // mask + if (mask_mat.channels() > 1) { + cv::extractChannel(mask_mat, mask_mat, little_endian_ ? 0 : mask_mat.channels() - 1); + } + if (mask_mat.type() != CV_32S) { + mask_mat.convertTo(mask_mat, CV_32S); + } + resized_mask = cpu::Resize(mask_mat, input_height, input_width, "nearest"); + tensor_mask = cpu::CVMat2Tensor(resized_mask); + } else { + // score + resized_score = cpu::Resize(mask_mat, input_height, input_width, "bilinear"); + tensor_score = cpu::CVMat2Tensor(resized_score); + std::vector axes = {0, 3, 1, 2}; + ::mmdeploy::operation::Context ctx(host, stream_); + OUTCOME_TRY(permute_.Apply(tensor_score, tensor_score, axes)); } - cv::Mat resized_mask = cpu::Resize(mask_mat, input_height, input_width, "nearest"); - - SegmentorOutput output{cpu::CVMat2Tensor(resized_mask), input_height, input_width, classes_}; + SegmentorOutput output{tensor_mask, tensor_score, input_height, input_width, classes_}; return to_value(output); } private: - static Result GetCvType(DataType type) { + static Result GetCvType(DataType type, int channel) { switch (type) { case DataType::kFLOAT: - return CV_32F; + return CV_32FC(channel); case DataType::kINT64: return CV_32SC2; case DataType::kINT32: @@ -84,7 +120,9 @@ class ResizeMask : public MMSegmentation { } protected: + ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute> permute_; int classes_{}; + bool with_argmax_{true}; bool little_endian_; }; diff --git a/csrc/mmdeploy/operation/cpu/CMakeLists.txt b/csrc/mmdeploy/operation/cpu/CMakeLists.txt index fa6a1de56..5607123de 100644 --- a/csrc/mmdeploy/operation/cpu/CMakeLists.txt +++ b/csrc/mmdeploy/operation/cpu/CMakeLists.txt @@ -11,7 +11,8 @@ set(SRCS resize.cpp crop.cpp flip.cpp warp_affine.cpp - crop_resize_pad.cpp) + crop_resize_pad.cpp + permute.cpp) mmdeploy_add_module(${PROJECT_NAME} "${SRCS}") diff --git a/csrc/mmdeploy/operation/cpu/permute.cpp b/csrc/mmdeploy/operation/cpu/permute.cpp new file mode 100644 index 000000000..44c98fe24 --- /dev/null +++ b/csrc/mmdeploy/operation/cpu/permute.cpp @@ -0,0 +1,91 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "mmdeploy/operation/vision.h" +#include "mmdeploy/utils/opencv/opencv_utils.h" + +namespace mmdeploy::operation::cpu { + +class PermuteImpl : public Permute { + public: + explicit PermuteImpl() {} + + Result apply(const Tensor& src, Tensor& dst, const std::vector& axes) override { + int ndim = src.shape().size(); + if (ndim != axes.size()) { + MMDEPLOY_ERROR("The size of axes should be equal to src, {} vs {}", axes.size(), ndim); + return Status(eInvalidArgument); + } + std::vector axes_vis(ndim, 0); + for (const auto& x : axes) { + if (x < 0 || x >= ndim || axes_vis[x]) { + MMDEPLOY_ERROR("Invalid axes"); + return Status(eInvalidArgument); + } + axes_vis[x] = 1; + } + + Tensor dst_tensor(src.desc()); + auto src_dims = src.shape(); + TensorShape dst_dims(ndim); + for (int i = 0; i < src_dims.size(); i++) { + dst_dims[i] = src_dims[axes[i]]; + } + dst_tensor.Reshape(dst_dims); + + std::vector dst_strides(ndim); + std::vector src_strides(ndim); + dst_strides[ndim - 1] = src_strides[ndim - 1] = 1; + for (int i = ndim - 2; i >= 0; i--) { + dst_strides[i] = dst_strides[i + 1] * dst_dims[i + 1]; + src_strides[i] = src_strides[i + 1] * src_dims[i + 1]; + } + + std::vector tmp(ndim); + for (int i = 0; i < ndim; i++) { + tmp[i] = src_strides[axes[i]]; + } + src_strides.swap(tmp); + + if (src.data_type() == DataType::kINT8) { + OUTCOME_TRY(PermuteDispatch(src, dst_tensor, src_strides, dst_strides)); + } else if (src.data_type() == DataType::kFLOAT) { + OUTCOME_TRY(PermuteDispatch(src, dst_tensor, src_strides, dst_strides)); + } else { + MMDEPLOY_ERROR("unsupported data type {}", src.data_type()); + return Status(eNotSupported); + } + dst = std::move(dst_tensor); + return success(); + } + + template + Result PermuteDispatch(const Tensor& src, Tensor& dst, const std::vector& src_strides, + const std::vector& dst_strides) { + auto shape = dst.shape(); + int ndim = src.shape().size(); + std::vector coord(ndim, 0); + auto dst_data = dst.data(); + auto src_data = src.data(); + + int i; + do { + dst_data[0] = src_data[0]; + for (i = ndim - 1; i >= 0; i--) { + if (++coord[i] == shape[i]) { + coord[i] = 0; + dst_data -= (shape[i] - 1) * dst_strides[i]; + src_data -= (shape[i] - 1) * src_strides[i]; + } else { + dst_data += dst_strides[i]; + src_data += src_strides[i]; + break; + } + } + } while (i >= 0); + return success(); + } +}; + +MMDEPLOY_REGISTER_FACTORY_FUNC(Permute, (cpu, 0), []() { return std::make_unique(); }); + +} // namespace mmdeploy::operation::cpu diff --git a/csrc/mmdeploy/operation/cuda/CMakeLists.txt b/csrc/mmdeploy/operation/cuda/CMakeLists.txt index 8322b9d2a..551f89977 100644 --- a/csrc/mmdeploy/operation/cuda/CMakeLists.txt +++ b/csrc/mmdeploy/operation/cuda/CMakeLists.txt @@ -19,7 +19,9 @@ set(SRCS resize.cpp crop.cu flip.cpp warp_affine.cpp - crop_resize_pad.cpp) + crop_resize_pad.cpp + permute.cpp + permute.cu) mmdeploy_add_module(${PROJECT_NAME} "${SRCS}") diff --git a/csrc/mmdeploy/operation/cuda/permute.cpp b/csrc/mmdeploy/operation/cuda/permute.cpp new file mode 100644 index 000000000..c5c87da88 --- /dev/null +++ b/csrc/mmdeploy/operation/cuda/permute.cpp @@ -0,0 +1,91 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "mmdeploy/operation/cuda/permute.h" + +#include + +#include "mmdeploy/operation/vision.h" + +namespace mmdeploy::operation::cuda { + +namespace impl { +template +void Permute(const T* src, const TensorStride& src_strides, T* dst, const TensorStride& dst_strides, + int ndim, int total, cudaStream_t stream); +} + +class PermuteImpl : public Permute { + public: + explicit PermuteImpl() {} + + Result apply(const Tensor& src, Tensor& dst, const std::vector& axes) override { + int ndim = src.shape().size(); + if (ndim != axes.size()) { + MMDEPLOY_ERROR("The size of axes should be equal of src, {} vs {}", axes.size(), ndim); + return Status(eInvalidArgument); + } + if (ndim > MAX_PERMUTE_DIM) { + MMDEPLOY_ERROR("Only support ndim < {}", MAX_PERMUTE_DIM); + return Status(eInvalidArgument); + } + std::vector axes_vis(ndim, 0); + for (const auto& x : axes) { + if (x < 0 || x >= ndim || axes_vis[x]) { + MMDEPLOY_ERROR("Invalid axes"); + return Status(eInvalidArgument); + } + axes_vis[x] = 1; + } + + Tensor dst_tensor(src.desc()); + auto src_dims = src.shape(); + TensorShape dst_dims(ndim); + for (int i = 0; i < src_dims.size(); i++) { + dst_dims[i] = src_dims[axes[i]]; + } + dst_tensor.Reshape(dst_dims); + + TensorStride dst_strides; + TensorStride src_strides; + + dst_strides[ndim - 1] = src_strides[ndim - 1] = 1; + for (int i = ndim - 2; i >= 0; i--) { + dst_strides[i] = dst_strides[i + 1] * dst_dims[i + 1]; + src_strides[i] = src_strides[i + 1] * src_dims[i + 1]; + } + + TensorStride tmp; + for (int i = 0; i < ndim; i++) { + tmp[i] = src_strides[axes[i]]; + } + src_strides = tmp; + + if (src.data_type() == DataType::kINT8) { + OUTCOME_TRY(PermuteDispatch(src, dst_tensor, src_strides, dst_strides)); + } else if (src.data_type() == DataType::kFLOAT) { + OUTCOME_TRY(PermuteDispatch(src, dst_tensor, src_strides, dst_strides)); + } else { + MMDEPLOY_ERROR("unsupported data type {}", src.data_type()); + return Status(eNotSupported); + } + dst = std::move(dst_tensor); + return success(); + } + + template + Result PermuteDispatch(const Tensor& src, Tensor& dst, const TensorStride& src_strides, + const TensorStride& dst_strides) { + auto src_data = src.data(); + auto dst_data = dst.data(); + auto ndim = src.shape().size(); + auto total = src.size(); + impl::Permute(src_data, src_strides, dst_data, dst_strides, ndim, total, + GetNative(stream())); + return success(); + } +}; + +MMDEPLOY_REGISTER_FACTORY_FUNC(Permute, (cuda, 0), + []() { return std::make_unique(); }); + +} // namespace mmdeploy::operation::cuda diff --git a/csrc/mmdeploy/operation/cuda/permute.cu b/csrc/mmdeploy/operation/cuda/permute.cu new file mode 100644 index 000000000..7f979ed3f --- /dev/null +++ b/csrc/mmdeploy/operation/cuda/permute.cu @@ -0,0 +1,49 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +#include "mmdeploy/operation/cuda/permute.h" + +namespace mmdeploy { +namespace operation { +namespace cuda { +namespace impl { + +template +__global__ void permute(const T* src, const TensorStride src_strides, T* dst, + const TensorStride dst_strides, int ndim, int total) { + int u = blockIdx.x * blockDim.x + threadIdx.x; + if (u >= total) { + return; + } + + int remaining = u; + int v = 0; + for (size_t i = 0; i < ndim; i++) { + int p = remaining / dst_strides.v_[i]; + remaining -= p * dst_strides.v_[i]; + v += p * src_strides.v_[i]; + } + dst[u] = src[v]; +} + +template +void Permute(const T* src, const TensorStride& src_strides, T* dst, const TensorStride& dst_strides, + int ndim, int total, cudaStream_t stream) { + int thread_num = 256; + int block_num = (total + thread_num - 1) / thread_num; + permute<<>>(src, src_strides, dst, dst_strides, ndim, total); +} + +template void Permute(const float* src, const TensorStride& src_strides, float* dst, + const TensorStride& dst_strides, int ndim, int total, + cudaStream_t stream); + +template void Permute(const uint8_t* src, const TensorStride& src_strides, uint8_t* dst, + const TensorStride& dst_strides, int ndim, int total, + cudaStream_t stream); + +} // namespace impl +} // namespace cuda +} // namespace operation +} // namespace mmdeploy diff --git a/csrc/mmdeploy/operation/cuda/permute.h b/csrc/mmdeploy/operation/cuda/permute.h new file mode 100644 index 000000000..7bbc0a404 --- /dev/null +++ b/csrc/mmdeploy/operation/cuda/permute.h @@ -0,0 +1,24 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef MMDEPLOY_OPERATION_CUDA_PERMUTE_H_ +#define MMDEPLOY_OPERATION_CUDA_PERMUTE_H_ + +#include + +#include + +namespace mmdeploy { +namespace operation { +namespace cuda { + +const int MAX_PERMUTE_DIM = 8; + +struct TensorStride { + int v_[MAX_PERMUTE_DIM]; + int& operator[](size_t idx) { return v_[idx]; } +}; + +} // namespace cuda +} // namespace operation +} // namespace mmdeploy + +#endif // MMDEPLOY_OPERATION_CUDA_PERMUTE_H_ diff --git a/csrc/mmdeploy/operation/vision.cpp b/csrc/mmdeploy/operation/vision.cpp index 18694f06a..0c0b13eb9 100644 --- a/csrc/mmdeploy/operation/vision.cpp +++ b/csrc/mmdeploy/operation/vision.cpp @@ -14,5 +14,6 @@ MMDEPLOY_DEFINE_REGISTRY(Crop); MMDEPLOY_DEFINE_REGISTRY(Flip); MMDEPLOY_DEFINE_REGISTRY(WarpAffine); MMDEPLOY_DEFINE_REGISTRY(CropResizePad); +MMDEPLOY_DEFINE_REGISTRY(Permute); } // namespace mmdeploy::operation diff --git a/csrc/mmdeploy/operation/vision.h b/csrc/mmdeploy/operation/vision.h index 10e699fed..013c3852b 100644 --- a/csrc/mmdeploy/operation/vision.h +++ b/csrc/mmdeploy/operation/vision.h @@ -92,6 +92,11 @@ class CropResizePad : public Operation { }; MMDEPLOY_DECLARE_REGISTRY(CropResizePad, unique_ptr()); +class Permute : public Operation { + public: + virtual Result apply(const Tensor& src, Tensor& dst, const std::vector& axes) = 0; +}; +MMDEPLOY_DECLARE_REGISTRY(Permute, unique_ptr()); } // namespace mmdeploy::operation diff --git a/csrc/mmdeploy/utils/opencv/opencv_utils.cpp b/csrc/mmdeploy/utils/opencv/opencv_utils.cpp index d3cd1ad87..4b4cb3a9c 100644 --- a/csrc/mmdeploy/utils/opencv/opencv_utils.cpp +++ b/csrc/mmdeploy/utils/opencv/opencv_utils.cpp @@ -126,7 +126,7 @@ cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width, const std::string& interpolation) { cv::Mat dst(dst_height, dst_width, src.type()); auto method = GetInterpolationMethod(interpolation).value(); - cv::resize(src, dst, dst.size(), method); + cv::resize(src, dst, dst.size(), 0, 0, method); return dst; } diff --git a/demo/csharp/image_segmentation/Program.cs b/demo/csharp/image_segmentation/Program.cs index 1318c001f..b2cfa3414 100644 --- a/demo/csharp/image_segmentation/Program.cs +++ b/demo/csharp/image_segmentation/Program.cs @@ -75,21 +75,50 @@ namespace image_segmentation unsafe { byte* data = colorMask.DataPointer; - fixed (int* _label = output[0].Mask) + if (output[0].Mask.Length > 0) { - int* label = _label; - for (int i = 0; i < output[0].Height; i++) + fixed (int* _label = output[0].Mask) { - for (int j = 0; j < output[0].Width; j++) + int* label = _label; + for (int i = 0; i < output[0].Height; i++) { - data[0] = palette[*label][0]; - data[1] = palette[*label][1]; - data[2] = palette[*label][2]; - data += 3; - label++; + for (int j = 0; j < output[0].Width; j++) + { + data[0] = palette[*label][0]; + data[1] = palette[*label][1]; + data[2] = palette[*label][2]; + data += 3; + label++; + } } } } + else + { + int pos = 0; + fixed (float* _score = output[0].Score) + { + float *score = _score; + int total = output[0].Height * output[0].Width; + for (int i = 0; i < output[0].Height; i++) + { + for (int j = 0; j < output[0].Width; j++) + { + List> scores = new List>(); + for (int k = 0; k < output[0].Classes; k++) + { + scores.Add(new Tuple(score[k * total + i * output[0].Width + j], k)); + } + scores.Sort(); + data[0] = palette[scores[^1].Item2][0]; + data[1] = palette[scores[^1].Item2][1]; + data[2] = palette[scores[^1].Item2][2]; + data += 3; + } + } + } + } + } colorMask = imgs[0] * 0.5 + colorMask * 0.5; diff --git a/demo/csrc/c/image_segmentation.cpp b/demo/csrc/c/image_segmentation.cpp index fae446b4f..df26d1585 100644 --- a/demo/csrc/c/image_segmentation.cpp +++ b/demo/csrc/c/image_segmentation.cpp @@ -1,6 +1,7 @@ // Copyright (c) OpenMMLab. All rights reserved. #include +#include #include #include #include @@ -59,8 +60,25 @@ int main(int argc, char* argv[]) { cv::Mat color_mask = cv::Mat::zeros(result->height, result->width, CV_8UC3); int pos = 0; + int total = color_mask.rows * color_mask.cols; + std::vector idxs(result->classes); for (auto iter = color_mask.begin(); iter != color_mask.end(); ++iter) { - *iter = palette[result->mask[pos++]]; + // output mask + if (result->mask) { + *iter = palette[result->mask[pos++]]; + } + // output score + if (result->score) { + std::iota(idxs.begin(), idxs.end(), 0); + auto k = + std::max_element(idxs.begin(), idxs.end(), + [&](int i, int j) { + return result->score[i * total + pos] < result->score[j * total + pos]; + }) - + idxs.begin(); + *iter = palette[k]; + pos += 1; + } } img = img * 0.5 + color_mask * 0.5; diff --git a/demo/csrc/cpp/segmentor.cxx b/demo/csrc/cpp/segmentor.cxx index 0c1dde49d..be5100f6a 100644 --- a/demo/csrc/cpp/segmentor.cxx +++ b/demo/csrc/cpp/segmentor.cxx @@ -3,6 +3,7 @@ #include "mmdeploy/segmentor.hpp" #include +#include #include #include #include @@ -47,8 +48,25 @@ int main(int argc, char* argv[]) { cv::Mat color_mask = cv::Mat::zeros(result->height, result->width, CV_8UC3); int pos = 0; + int total = color_mask.rows * color_mask.cols; + std::vector idxs(result->classes); for (auto iter = color_mask.begin(); iter != color_mask.end(); ++iter) { - *iter = palette[result->mask[pos++]]; + // output mask + if (result->mask) { + *iter = palette[result->mask[pos++]]; + } + // output score + if (result->score) { + std::iota(idxs.begin(), idxs.end(), 0); + auto k = + std::max_element(idxs.begin(), idxs.end(), + [&](int i, int j) { + return result->score[pos + i * total] < result->score[pos + j * total]; + }) - + idxs.begin(); + *iter = palette[k]; + pos += 1; + } } img = img * 0.5 + color_mask * 0.5; diff --git a/demo/python/image_segmentation.py b/demo/python/image_segmentation.py index 32391f434..e70b08891 100644 --- a/demo/python/image_segmentation.py +++ b/demo/python/image_segmentation.py @@ -35,6 +35,8 @@ def main(): segmentor = Segmentor( model_path=args.model_path, device_name=args.device_name, device_id=0) seg = segmentor(img) + if seg.dtype == np.float32: + seg = np.argmax(seg, axis=0) palette = get_palette() color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) diff --git a/tests/test_csrc/preprocess/test_permute.cpp b/tests/test_csrc/preprocess/test_permute.cpp new file mode 100644 index 000000000..a12a4aad6 --- /dev/null +++ b/tests/test_csrc/preprocess/test_permute.cpp @@ -0,0 +1,121 @@ + +// Copyright (c) OpenMMLab. All rights reserved. + +#include + +#include "catch.hpp" +#include "mmdeploy/core/mat.h" +#include "mmdeploy/core/tensor.h" +#include "mmdeploy/core/utils/device_utils.h" +#include "mmdeploy/operation/managed.h" +#include "mmdeploy/operation/vision.h" +#include "mmdeploy/preprocess/transform/transform.h" +#include "test_resource.h" +#include "test_utils.h" + +using namespace mmdeploy; +using namespace framework; +using namespace std; +using namespace mmdeploy::test; + +template +bool CheckEqual(const Tensor& res, const vector& expected) { + auto r = res.data(); + auto e = expected.data(); + for (int i = 0; i < expected.size(); i++) { + if (r[i] != e[i]) { + return false; + } + } + return true; +} + +template +void TestPermute(const Tensor& src, const vector& axes, const vector& expected) { + auto gResource = MMDeployTestResources::Get(); + for (auto const& device_name : gResource.device_names()) { + Device device{device_name.c_str()}; + Stream stream{device}; + ::mmdeploy::operation::Context ctx(device, stream); + auto permute = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); + Tensor dst; + auto ret = permute.Apply(src, dst, axes); + REQUIRE(!ret.has_error()); + const Device kHost{"cpu"}; + auto host_tensor = MakeAvailableOnDevice(dst, kHost, stream); + REQUIRE(CheckEqual(host_tensor.value(), expected)); + } +} + +void TestPermuteWrongArgs(const Tensor& src) { + int sz = src.shape().size(); + vector oaxes(sz); + std::iota(oaxes.begin(), oaxes.end(), 0); + + auto gResource = MMDeployTestResources::Get(); + for (auto const& device_name : gResource.device_names()) { + Device device{device_name.c_str()}; + Stream stream{device}; + ::mmdeploy::operation::Context ctx(device, stream); + auto permute = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); + Tensor dst; + { + auto axes = oaxes; + axes[0]--; + auto ret = permute.Apply(src, dst, axes); + REQUIRE(ret.has_error()); + } + { + auto axes = oaxes; + axes.back()++; + auto ret = permute.Apply(src, dst, axes); + REQUIRE(ret.has_error()); + } + { + auto axes = oaxes; + axes[0] = axes[1]; + auto ret = permute.Apply(src, dst, axes); + REQUIRE(ret.has_error()); + } + } +} + +TEST_CASE("operation Permute", "[permute]") { + const Device kHost{"cpu"}; + const int kSize = 2 * 3 * 2 * 4; + vector data(kSize); + std::iota(data.begin(), data.end(), 0); // [0, 48) + TensorDesc desc = {kHost, DataType::kINT8, {kSize}}; + Tensor tensor(desc); + memcpy(tensor.data(), data.data(), data.size() * sizeof(uint8_t)); + + SECTION("permute: wrong axes") { + Tensor src = tensor; + src.Reshape({6, 8}); + TestPermuteWrongArgs(src); + } + + SECTION("permute: dims 4") { + Tensor src = tensor; + src.Reshape({2, 3, 2, 4}); + vector axes = {1, 0, 3, 2}; + vector expected = {0, 4, 1, 5, 2, 6, 3, 7, 24, 28, 25, 29, 26, 30, 27, 31, + 8, 12, 9, 13, 10, 14, 11, 15, 32, 36, 33, 37, 34, 38, 35, 39, + 16, 20, 17, 21, 18, 22, 19, 23, 40, 44, 41, 45, 42, 46, 43, 47}; + Tensor dst(src.desc()); + memcpy(dst.data(), expected.data(), data.size() * sizeof(uint8_t)); + TestPermute(src, axes, expected); + } + + SECTION("permute: dims 5") { + Tensor src = tensor; + src.Reshape({2, 3, 1, 2, 4}); + vector axes = {2, 0, 1, 4, 3}; + vector expected = {0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15, + 16, 20, 17, 21, 18, 22, 19, 23, 24, 28, 25, 29, 26, 30, 27, 31, + 32, 36, 33, 37, 34, 38, 35, 39, 40, 44, 41, 45, 42, 46, 43, 47}; + Tensor dst(src.desc()); + memcpy(dst.data(), expected.data(), data.size() * sizeof(uint8_t)); + TestPermute(src, axes, expected); + } +}