[Feature] Support feature map output for mmsegmentation (#1625)
* add feature map output for mmseg
* update api
* update demo
* fix return
* update format_shape
* fix lint
* update csharp demo
* update python demo && api
* fix coreml build
* fix lint
* better sort
* update
* update cpp demo & add missing header
* change to CHW
* update csharp api
* update isort version to 5.12.0
* fix python api
* fix log
* more detail api docs
* isort support python3.7
* remove isort change
* remove whitespace
* axes check
* remove FormatShapeImpl
* minor
* add permute tc
* remove stride buffer
(cherry picked from commit b85f34141b)
pull/1726/head
parent
a9f8d8c951
commit
1f56eea807
csrc/mmdeploy
apis
utils/opencv
demo
csharp/image_segmentation
python
tests/test_csrc/preprocess
|
@ -119,10 +119,17 @@ int mmdeploy_segmentor_get_result(mmdeploy_value_t output, mmdeploy_segmentation
|
|||
results_ptr->height = segmentor_output.height;
|
||||
results_ptr->width = segmentor_output.width;
|
||||
results_ptr->classes = segmentor_output.classes;
|
||||
auto mask_size = results_ptr->height * results_ptr->width;
|
||||
auto& mask = segmentor_output.mask;
|
||||
results_ptr->mask = mask.data<int>();
|
||||
buffers[i] = mask.buffer();
|
||||
auto& score = segmentor_output.score;
|
||||
results_ptr->mask = nullptr;
|
||||
results_ptr->score = nullptr;
|
||||
if (mask.shape().size()) {
|
||||
results_ptr->mask = mask.data<int>();
|
||||
buffers[i] = mask.buffer();
|
||||
} else {
|
||||
results_ptr->score = score.data<float>();
|
||||
buffers[i] = score.buffer();
|
||||
}
|
||||
}
|
||||
|
||||
*results = results_data;
|
||||
|
|
|
@ -17,11 +17,14 @@ extern "C" {
|
|||
#endif
|
||||
|
||||
typedef struct mmdeploy_segmentation_t {
|
||||
int height; ///< height of \p mask that equals to the input image's height
|
||||
int width; ///< width of \p mask that equals to the input image's width
|
||||
int classes; ///< the number of labels in \p mask
|
||||
int* mask; ///< segmentation mask of the input image, in which mask[i * width + j] indicates
|
||||
///< the label id of pixel at (i, j)
|
||||
int height; ///< height of \p mask that equals to the input image's height
|
||||
int width; ///< width of \p mask that equals to the input image's width
|
||||
int classes; ///< the number of labels in \p mask
|
||||
int* mask; ///< segmentation mask of the input image, in which mask[i * width + j] indicates
|
||||
///< the label id of pixel at (i, j), this field might be null
|
||||
float* score; ///< segmentation score map of the input image in CHW format, in which
|
||||
///< score[height * width * k + i * width + j] indicates the score
|
||||
///< of class k at pixel (i, j), this field might be null
|
||||
} mmdeploy_segmentation_t;
|
||||
|
||||
typedef struct mmdeploy_segmentor* mmdeploy_segmentor_t;
|
||||
|
|
|
@ -10,6 +10,7 @@ namespace MMDeploy
|
|||
public int Width;
|
||||
public int Classes;
|
||||
public int* Mask;
|
||||
public float* Score;
|
||||
}
|
||||
#pragma warning restore 0649
|
||||
|
||||
|
@ -34,10 +35,16 @@ namespace MMDeploy
|
|||
public int Classes;
|
||||
|
||||
/// <summary>
|
||||
/// Mask data.
|
||||
/// Mask data, mask[i * width + j] indicates the label id of pixel at (i, j).
|
||||
/// </summary>
|
||||
public int[] Mask;
|
||||
|
||||
/// <summary>
|
||||
/// Score data, score[height * width * k + i * width + j] indicates the score
|
||||
/// of class k at pixel (i, j).
|
||||
/// </summary>
|
||||
public float[] Score;
|
||||
|
||||
/// <summary>
|
||||
/// Initializes a new instance of the <see cref="SegmentorOutput"/> struct.
|
||||
/// </summary>
|
||||
|
@ -45,13 +52,31 @@ namespace MMDeploy
|
|||
/// <param name="width">width.</param>
|
||||
/// <param name="classes">classes.</param>
|
||||
/// <param name="mask">mask.</param>
|
||||
public SegmentorOutput(int height, int width, int classes, int[] mask)
|
||||
/// <param name="score">score.</param>
|
||||
public SegmentorOutput(int height, int width, int classes, int[] mask, float[] score)
|
||||
{
|
||||
Height = height;
|
||||
Width = width;
|
||||
Classes = classes;
|
||||
Mask = new int[Height * Width];
|
||||
Array.Copy(mask, this.Mask, mask.Length);
|
||||
if (mask.Length > 0)
|
||||
{
|
||||
Mask = new int[Height * Width];
|
||||
Array.Copy(mask, this.Mask, mask.Length);
|
||||
}
|
||||
else
|
||||
{
|
||||
Mask = new int[] { };
|
||||
}
|
||||
|
||||
if (score.Length > 0)
|
||||
{
|
||||
Score = new float[Height * Width * Classes];
|
||||
Array.Copy(score, this.Score, score.Length);
|
||||
}
|
||||
else
|
||||
{
|
||||
Score = new float[] { };
|
||||
}
|
||||
}
|
||||
|
||||
internal unsafe SegmentorOutput(CSegment* result)
|
||||
|
@ -59,11 +84,34 @@ namespace MMDeploy
|
|||
Height = result->Height;
|
||||
Width = result->Width;
|
||||
Classes = result->Classes;
|
||||
Mask = new int[Height * Width];
|
||||
int nbytes = Height * Width * sizeof(int);
|
||||
fixed (int* data = this.Mask)
|
||||
if (result->Mask != null)
|
||||
{
|
||||
Buffer.MemoryCopy(result->Mask, data, nbytes, nbytes);
|
||||
Mask = new int[Height * Width];
|
||||
|
||||
int nbytes = Height * Width * sizeof(int);
|
||||
fixed (int* data = this.Mask)
|
||||
{
|
||||
Buffer.MemoryCopy(result->Mask, data, nbytes, nbytes);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Mask = new int[] { };
|
||||
}
|
||||
|
||||
if (result->Score != null)
|
||||
{
|
||||
Score = new float[Height * Width * Classes];
|
||||
|
||||
int nbytes = Height * Width * Classes * sizeof(float);
|
||||
fixed (float* data = this.Score)
|
||||
{
|
||||
Buffer.MemoryCopy(result->Score, data, nbytes, nbytes);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Score = new float[] { };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,12 +37,22 @@ class PySegmentor {
|
|||
|
||||
std::vector<py::array> rets(mats.size());
|
||||
for (size_t i = 0; i < mats.size(); ++i) {
|
||||
rets[i] = {
|
||||
{segm[i].height, segm[i].width}, // shape
|
||||
segm[i].mask, // data
|
||||
py::capsule(new Sptr(holder), // handle
|
||||
[](void* p) { delete reinterpret_cast<Sptr*>(p); }) //
|
||||
};
|
||||
if (segm[i].mask != nullptr) {
|
||||
rets[i] = {
|
||||
{segm[i].height, segm[i].width}, // shape
|
||||
segm[i].mask, // mask
|
||||
py::capsule(new Sptr(holder), // handle
|
||||
[](void* p) { delete reinterpret_cast<Sptr*>(p); }) //
|
||||
};
|
||||
}
|
||||
if (segm[i].score != nullptr) {
|
||||
rets[i] = {
|
||||
{segm[i].classes, segm[i].height, segm[i].width}, // shape
|
||||
segm[i].score, // score
|
||||
py::capsule(new Sptr(holder), // handle
|
||||
[](void* p) { delete reinterpret_cast<Sptr*>(p); }) //
|
||||
};
|
||||
}
|
||||
}
|
||||
return rets;
|
||||
}
|
||||
|
|
|
@ -4,8 +4,7 @@ project(mmdeploy_mmaction)
|
|||
|
||||
file(GLOB SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
|
||||
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
|
||||
add_subdirectory(cpu)
|
||||
add_subdirectory(cuda)
|
||||
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE
|
||||
mmdeploy_operation
|
||||
mmdeploy_transform
|
||||
|
|
|
@ -1,15 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
project(mmdeploy_mmaction_cpu_impl CXX)
|
||||
|
||||
if ("cpu" IN_LIST MMDEPLOY_TARGET_DEVICES)
|
||||
add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp)
|
||||
set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1)
|
||||
if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC))
|
||||
target_compile_options(${PROJECT_NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
|
||||
endif ()
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE
|
||||
mmdeploy::core)
|
||||
target_link_libraries(mmdeploy_mmaction PRIVATE ${PROJECT_NAME})
|
||||
mmdeploy_export(${PROJECT_NAME})
|
||||
endif ()
|
|
@ -1,67 +0,0 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/codebase/mmaction/format_shape.h"
|
||||
#include "mmdeploy/core/utils/device_utils.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace mmdeploy::mmaction::cpu {
|
||||
|
||||
class FormatShapeImpl : public FormatShapeOp {
|
||||
public:
|
||||
explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {}
|
||||
|
||||
protected:
|
||||
Device host_{0, 0};
|
||||
|
||||
const Device& GetDevice() { return host_; }
|
||||
|
||||
Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
|
||||
const std::vector<int>& permutation) {
|
||||
Tensor dst(src.desc());
|
||||
TensorShape shape(src.shape().size());
|
||||
for (int i = 0; i < shape.size(); i++) {
|
||||
shape[i] = src.shape(permutation[i]);
|
||||
}
|
||||
dst.Reshape(shape);
|
||||
int ndim = shape.size();
|
||||
std::vector<int> dst_strides(ndim);
|
||||
std::vector<int> src_strides(ndim);
|
||||
dst_strides[ndim - 1] = src_strides[ndim - 1] = 1;
|
||||
for (int i = ndim - 2; i >= 0; i--) {
|
||||
dst_strides[i] = dst_strides[i + 1] * shape[i + 1];
|
||||
src_strides[i] = src_strides[i + 1] * src_dims[i + 1];
|
||||
}
|
||||
std::vector<int> tmp(ndim);
|
||||
for (int i = 0; i < ndim; i++) {
|
||||
tmp[i] = src_strides[permutation[i]];
|
||||
}
|
||||
src_strides.swap(tmp);
|
||||
std::vector<int> coord(ndim, 0);
|
||||
auto dst_data = dst.data<float>();
|
||||
auto src_data = src.data<float>();
|
||||
|
||||
int i;
|
||||
do {
|
||||
dst_data[0] = src_data[0];
|
||||
for (i = ndim - 1; i >= 0; i--) {
|
||||
if (++coord[i] == shape[i]) {
|
||||
coord[i] = 0;
|
||||
dst_data -= (shape[i] - 1) * dst_strides[i];
|
||||
src_data -= (shape[i] - 1) * src_strides[i];
|
||||
} else {
|
||||
dst_data += dst_strides[i];
|
||||
src_data += src_strides[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
} while (i >= 0);
|
||||
return dst;
|
||||
}
|
||||
};
|
||||
|
||||
// Registers a factory for FormatShapeOp under the (cpu, 0) key; the factory
// forwards `input_format` to the FormatShapeImpl defined in this namespace.
MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cpu, 0), [](std::string input_format) {
|
||||
return std::make_unique<FormatShapeImpl>(std::move(input_format));
|
||||
});
|
||||
|
||||
} // namespace mmdeploy::mmaction::cpu
|
|
@ -1,18 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
if (NOT "cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
|
||||
return()
|
||||
endif ()
|
||||
|
||||
project(mmdeploy_mmaction_cuda_impl CXX)
|
||||
|
||||
add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp transpose.cu)
|
||||
set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1)
|
||||
if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC))
|
||||
target_compile_options(${PROJECT_NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
|
||||
endif ()
|
||||
target_include_directories(${PROJECT_NAME} PRIVATE
|
||||
${CUDA_INCLUDE_DIRS})
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE
|
||||
mmdeploy::core)
|
||||
target_link_libraries(mmdeploy_mmaction PRIVATE ${PROJECT_NAME})
|
||||
mmdeploy_export(${PROJECT_NAME})
|
|
@ -1,66 +0,0 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "cuda_runtime.h"
|
||||
#include "mmdeploy/codebase/mmaction/format_shape.h"
|
||||
#include "mmdeploy/core/utils/device_utils.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace mmdeploy::mmaction::cuda {
|
||||
|
||||
template <typename T>
|
||||
void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim,
|
||||
int total, cudaStream_t stream);
|
||||
|
||||
class FormatShapeImpl : public FormatShapeOp {
|
||||
public:
|
||||
explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {}
|
||||
|
||||
protected:
|
||||
const Device& GetDevice() { return device(); }
|
||||
|
||||
Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
|
||||
const std::vector<int>& permutation) {
|
||||
Tensor dst(src.desc());
|
||||
TensorShape shape(src.shape().size());
|
||||
for (int i = 0; i < shape.size(); i++) {
|
||||
shape[i] = src.shape(permutation[i]);
|
||||
}
|
||||
dst.Reshape(shape);
|
||||
|
||||
auto ndim = src_dims.size();
|
||||
std::vector<int> dst_dims(ndim);
|
||||
for (int i = 0; i < ndim; i++) {
|
||||
dst_dims[i] = src_dims[permutation[i]];
|
||||
}
|
||||
|
||||
std::vector<int> src_strides(ndim);
|
||||
std::vector<int> dst_strides(ndim);
|
||||
std::vector<int> buffer(ndim);
|
||||
buffer.back() = 1;
|
||||
dst_strides.back() = 1;
|
||||
for (int i = ndim - 1; i > 0; i--) {
|
||||
buffer[i - 1] = buffer[i] * src_dims[i];
|
||||
dst_strides[i - 1] = dst_strides[i] * dst_dims[i];
|
||||
}
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
src_strides[i] = buffer[permutation[i]];
|
||||
}
|
||||
|
||||
Buffer _src_strides(Device("cuda"), sizeof(int) * ndim);
|
||||
Buffer _dst_strides(Device("cuda"), sizeof(int) * ndim);
|
||||
OUTCOME_TRY(stream().Copy(src_strides.data(), _src_strides));
|
||||
OUTCOME_TRY(stream().Copy(dst_strides.data(), _dst_strides));
|
||||
|
||||
::mmdeploy::mmaction::cuda::Transpose(src.data<float>(), GetNative<int*>(_src_strides),
|
||||
dst.data<float>(), GetNative<int*>(_dst_strides), ndim,
|
||||
src.size(), (cudaStream_t)stream().GetNative());
|
||||
return dst;
|
||||
}
|
||||
};
|
||||
|
||||
// Registers a factory for FormatShapeOp under the (cuda, 0) key; the factory
// forwards `input_format` to the FormatShapeImpl defined in this namespace.
MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cuda, 0), [](std::string input_format) {
|
||||
return std::make_unique<FormatShapeImpl>(std::move(input_format));
|
||||
});
|
||||
|
||||
} // namespace mmdeploy::mmaction::cuda
|
|
@ -1,42 +0,0 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace mmaction {
|
||||
namespace cuda {
|
||||
|
||||
template <typename T>
|
||||
__global__ void transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides,
|
||||
int ndim, int total) {
|
||||
int u = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (u >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
int remaining = u;
|
||||
int v = 0;
|
||||
for (int i = 0; i < ndim; i++) {
|
||||
int p = remaining / dst_strides[i];
|
||||
remaining -= p * dst_strides[i];
|
||||
v += p * src_strides[i];
|
||||
}
|
||||
dst[u] = src[v];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim,
|
||||
int total, cudaStream_t stream) {
|
||||
int thread_num = 256;
|
||||
int block_num = (total + thread_num - 1) / thread_num;
|
||||
transpose<T>
|
||||
<<<block_num, thread_num, 0, stream>>>(src, src_strides, dst, dst_strides, ndim, total);
|
||||
}
|
||||
|
||||
template void Transpose<float>(const float* src, const int* src_strides, float* dst,
|
||||
const int* dst_strides, int ndim, int total, cudaStream_t stream);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace mmaction
|
||||
} // namespace mmdeploy
|
|
@ -10,28 +10,47 @@ using namespace std;
|
|||
namespace mmdeploy::mmaction {
|
||||
|
||||
FormatShape::FormatShape(const Value& args) {
|
||||
auto input_format = args.value("input_format", std::string(""));
|
||||
if (input_format != "NCHW" && input_format != "NCTHW") {
|
||||
input_format_ = args.value("input_format", std::string(""));
|
||||
if (input_format_ != "NCHW" && input_format_ != "NCTHW") {
|
||||
MMDEPLOY_ERROR("'input_format' should be 'NCHW' or 'NCTHW'");
|
||||
throw_exception(eInvalidArgument);
|
||||
}
|
||||
format_ = operation::Managed<mmdeploy::mmaction::FormatShapeOp>::Create(input_format);
|
||||
permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create();
|
||||
}
|
||||
|
||||
Result<void> FormatShapeOp::apply(const std::vector<Tensor>& images, Tensor& output, int clip_len,
|
||||
int num_clips) {
|
||||
Result<void> FormatShape::MergeInputs(const std::vector<Tensor>& images, Tensor& inputs) {
|
||||
auto N = static_cast<int64_t>(images.size());
|
||||
auto H = images[0].shape(1);
|
||||
auto W = images[0].shape(2);
|
||||
auto C = images[0].shape(3);
|
||||
auto& device = operation::gContext().device();
|
||||
auto& stream = operation::gContext().stream();
|
||||
|
||||
TensorDesc desc = {device, DataType::kFLOAT, {N, H, W, C}};
|
||||
inputs = Tensor(desc);
|
||||
auto offset = 0UL;
|
||||
auto n_item = H * W * C;
|
||||
auto copy_size = n_item * sizeof(float);
|
||||
for (int i = 0; i < N; i++) {
|
||||
auto src_buffer = images[i].buffer();
|
||||
auto dst_buffer = inputs.buffer();
|
||||
OUTCOME_TRY(stream.Copy(src_buffer, dst_buffer, copy_size, 0, offset));
|
||||
offset += copy_size;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
Result<void> FormatShape::Format(const std::vector<Tensor>& images, Tensor& output, int clip_len,
|
||||
int num_clips) {
|
||||
Tensor inputs;
|
||||
OUTCOME_TRY(MergeInputs(images, inputs));
|
||||
if (GetDevice().is_host()) {
|
||||
OUTCOME_TRY(stream().Wait());
|
||||
}
|
||||
|
||||
// Tensor dst;
|
||||
if (input_format_ == "NCHW") {
|
||||
OUTCOME_TRY(output, FormatNCHW(inputs, clip_len, num_clips));
|
||||
OUTCOME_TRY(FormatNCHW(inputs, clip_len, num_clips, output));
|
||||
}
|
||||
if (input_format_ == "NCTHW") {
|
||||
OUTCOME_TRY(output, FormatNCTHW(inputs, clip_len, num_clips));
|
||||
OUTCOME_TRY(FormatNCTHW(inputs, clip_len, num_clips, output));
|
||||
}
|
||||
|
||||
TensorShape expand_dim = output.shape();
|
||||
|
@ -41,35 +60,13 @@ Result<void> FormatShapeOp::apply(const std::vector<Tensor>& images, Tensor& out
|
|||
return success();
|
||||
}
|
||||
|
||||
Result<void> FormatShapeOp::MergeInputs(const std::vector<Tensor>& images, Tensor& inputs) {
|
||||
auto N = static_cast<int64_t>(images.size());
|
||||
auto H = images[0].shape(1);
|
||||
auto W = images[0].shape(2);
|
||||
auto C = images[0].shape(3);
|
||||
|
||||
TensorDesc desc = {GetDevice(), DataType::kFLOAT, {N, H, W, C}};
|
||||
inputs = Tensor(desc);
|
||||
auto offset = 0UL;
|
||||
auto n_item = H * W * C;
|
||||
auto copy_size = n_item * sizeof(float);
|
||||
for (int i = 0; i < N; i++) {
|
||||
auto src_buffer = images[i].buffer();
|
||||
auto dst_buffer = inputs.buffer();
|
||||
OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
|
||||
offset += copy_size;
|
||||
}
|
||||
Result<void> FormatShape::FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) {
|
||||
const vector<int> axes = {0, 3, 1, 2};
|
||||
OUTCOME_TRY(permute_.Apply(src, dst, axes));
|
||||
return success();
|
||||
}
|
||||
|
||||
Result<Tensor> FormatShapeOp::FormatNCHW(Tensor& src, int clip_len, int num_clips) {
|
||||
auto N = src.shape(0);
|
||||
auto H = src.shape(1);
|
||||
auto W = src.shape(2);
|
||||
auto C = src.shape(3);
|
||||
return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
|
||||
}
|
||||
|
||||
Result<Tensor> FormatShapeOp::FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
|
||||
Result<void> FormatShape::FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) {
|
||||
auto N = src.shape(0);
|
||||
auto H = src.shape(1);
|
||||
auto W = src.shape(2);
|
||||
|
@ -80,8 +77,9 @@ Result<Tensor> FormatShapeOp::FormatNCTHW(Tensor& src, int clip_len, int num_cli
|
|||
}
|
||||
int M = N / L;
|
||||
src.Reshape({M, L, H, W, C});
|
||||
|
||||
return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
|
||||
const vector<int> axes = {0, 4, 1, 2, 3};
|
||||
OUTCOME_TRY(permute_.Apply(src, dst, axes));
|
||||
return success();
|
||||
}
|
||||
|
||||
Result<void> FormatShape::Apply(Value& data) {
|
||||
|
@ -119,7 +117,7 @@ Result<void> FormatShape::Apply(Value& data) {
|
|||
|
||||
Tensor dst;
|
||||
data = Value{};
|
||||
OUTCOME_TRY(format_.Apply(images, dst, clip_len, num_clips));
|
||||
OUTCOME_TRY(Format(images, dst, clip_len, num_clips));
|
||||
data["img"] = std::move(dst);
|
||||
|
||||
return success();
|
||||
|
@ -127,6 +125,4 @@ Result<void> FormatShape::Apply(Value& data) {
|
|||
|
||||
MMDEPLOY_REGISTER_TRANSFORM(FormatShape);
|
||||
|
||||
MMDEPLOY_DEFINE_REGISTRY(FormatShapeOp);
|
||||
|
||||
} // namespace mmdeploy::mmaction
|
||||
|
|
|
@ -1,50 +1,38 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#ifndef MMDEPLOY_SRC_CODEBASE_MMACTION_FORMAT_SHAPE_H_
|
||||
#define MMDEPLOY_SRC_CODEBASE_MMACTION_FORMAT_SHAPE_H_
|
||||
#ifndef MMDEPLOY_CODEBASE_MMACTION_FORMAT_SHAPE_H_
|
||||
#define MMDEPLOY_CODEBASE_MMACTION_FORMAT_SHAPE_H_
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
#include "mmdeploy/operation/managed.h"
|
||||
#include "mmdeploy/operation/vision.h"
|
||||
#include "mmdeploy/preprocess/transform/transform.h"
|
||||
|
||||
namespace mmdeploy::mmaction {
|
||||
|
||||
class FormatShapeOp : public operation::Operation {
|
||||
public:
|
||||
explicit FormatShapeOp(std::string input_format) : input_format_(std::move(input_format)){};
|
||||
|
||||
Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
|
||||
int num_clips);
|
||||
|
||||
virtual const Device& GetDevice() = 0;
|
||||
|
||||
virtual Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
|
||||
const std::vector<int>& permutation) = 0;
|
||||
|
||||
Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips);
|
||||
|
||||
Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips);
|
||||
|
||||
Result<void> MergeInputs(const std::vector<Tensor>& images, Tensor& inputs);
|
||||
|
||||
protected:
|
||||
std::string input_format_;
|
||||
};
|
||||
|
||||
class FormatShape : public Transform {
|
||||
public:
|
||||
explicit FormatShape(const Value& args);
|
||||
|
||||
Result<void> Apply(Value& data) override;
|
||||
|
||||
private:
|
||||
operation::Managed<FormatShapeOp> format_;
|
||||
};
|
||||
Result<void> Format(const std::vector<Tensor>& images, Tensor& output, int clip_len,
|
||||
int num_clips);
|
||||
|
||||
MMDEPLOY_DECLARE_REGISTRY(FormatShapeOp, std::unique_ptr<FormatShapeOp>(std::string input_format));
|
||||
Result<void> FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst);
|
||||
|
||||
Result<void> FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst);
|
||||
|
||||
Result<void> MergeInputs(const std::vector<Tensor>& images, Tensor& inputs);
|
||||
|
||||
private:
|
||||
std::string input_format_;
|
||||
operation::Managed<operation::Permute> permute_;
|
||||
};
|
||||
|
||||
} // namespace mmdeploy::mmaction
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#ifndef MMDEPLOY_SRC_CODEBASE_MMACTION_MMACTION_H_
|
||||
#define MMDEPLOY_SRC_CODEBASE_MMACTION_MMACTION_H_
|
||||
#ifndef MMDEPLOY_CODEBASE_MMACTION_MMACTION_H_
|
||||
#define MMDEPLOY_CODEBASE_MMACTION_MMACTION_H_
|
||||
|
||||
#include "mmdeploy/codebase/common.h"
|
||||
#include "mmdeploy/core/device.h"
|
||||
|
|
|
@ -4,7 +4,9 @@ project(mmdeploy_mmseg)
|
|||
|
||||
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
|
||||
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_opencv_utils)
|
||||
target_link_libraries(${PROJECT_NAME} PRIVATE
|
||||
mmdeploy_opencv_utils
|
||||
mmdeploy_operation)
|
||||
add_library(mmdeploy::mmseg ALIAS ${PROJECT_NAME})
|
||||
|
||||
set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} segmentor CACHE INTERNAL "")
|
||||
|
|
|
@ -12,10 +12,11 @@ namespace mmdeploy::mmseg {
|
|||
|
||||
struct SegmentorOutput {
|
||||
Tensor mask;
|
||||
Tensor score;
|
||||
int height;
|
||||
int width;
|
||||
int classes;
|
||||
MMDEPLOY_ARCHIVE_MEMBERS(mask, height, width, classes);
|
||||
MMDEPLOY_ARCHIVE_MEMBERS(mask, score, height, width, classes);
|
||||
};
|
||||
|
||||
MMDEPLOY_DECLARE_CODEBASE(MMSegmentation, mmseg);
|
||||
|
|
|
@ -5,6 +5,8 @@
|
|||
#include "mmdeploy/core/tensor.h"
|
||||
#include "mmdeploy/core/utils/device_utils.h"
|
||||
#include "mmdeploy/core/utils/formatter.h"
|
||||
#include "mmdeploy/operation/managed.h"
|
||||
#include "mmdeploy/operation/vision.h"
|
||||
#include "mmdeploy/preprocess/transform/transform.h"
|
||||
#include "opencv_utils.h"
|
||||
|
||||
|
@ -18,7 +20,10 @@ class ResizeMask : public MMSegmentation {
|
|||
explicit ResizeMask(const Value &cfg) : MMSegmentation(cfg) {
|
||||
try {
|
||||
classes_ = cfg["params"]["num_classes"].get<int>();
|
||||
with_argmax_ = cfg["params"].value("with_argmax", true);
|
||||
little_endian_ = IsLittleEndian();
|
||||
::mmdeploy::operation::Context ctx(Device("cpu"), stream_);
|
||||
permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create();
|
||||
} catch (const std::exception &e) {
|
||||
MMDEPLOY_ERROR("no ['params']['num_classes'] is specified in cfg: {}", cfg);
|
||||
throw_exception(eInvalidArgument);
|
||||
|
@ -31,40 +36,71 @@ class ResizeMask : public MMSegmentation {
|
|||
auto mask = inference_result["output"].get<Tensor>();
|
||||
MMDEPLOY_DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(),
|
||||
mask.shape(), mask.data_type());
|
||||
if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) {
|
||||
if (!(mask.shape().size() == 4 && mask.shape(0) == 1)) {
|
||||
MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}", mask.shape());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
if ((mask.shape(1) != 1) && with_argmax_) {
|
||||
MMDEPLOY_ERROR("probability feat map with shape: {} requires `with_argmax_=false`",
|
||||
mask.shape());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
if ((mask.data_type() != DataType::kFLOAT) && !with_argmax_) {
|
||||
MMDEPLOY_ERROR("probability feat map only support float32 output");
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
auto channel = (int)mask.shape(1);
|
||||
auto height = (int)mask.shape(2);
|
||||
auto width = (int)mask.shape(3);
|
||||
auto input_height = preprocess_result["img_metas"]["ori_shape"][1].get<int>();
|
||||
auto input_width = preprocess_result["img_metas"]["ori_shape"][2].get<int>();
|
||||
Device host{"cpu"};
|
||||
OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(mask, host, stream_));
|
||||
OUTCOME_TRY(stream_.Wait());
|
||||
if (!with_argmax_) {
|
||||
// (C, H, W) -> (H, W, C)
|
||||
::mmdeploy::operation::Context ctx(host, stream_);
|
||||
std::vector<int> axes = {0, 2, 3, 1};
|
||||
OUTCOME_TRY(permute_.Apply(host_tensor, host_tensor, axes));
|
||||
}
|
||||
|
||||
OUTCOME_TRY(auto cv_type, GetCvType(mask.data_type()));
|
||||
OUTCOME_TRY(auto cv_type, GetCvType(mask.data_type(), channel));
|
||||
cv::Mat mask_mat(height, width, cv_type, host_tensor.data());
|
||||
|
||||
if (mask_mat.channels() > 1) {
|
||||
cv::extractChannel(mask_mat, mask_mat, little_endian_ ? 0 : mask_mat.channels() - 1);
|
||||
}
|
||||
if (mask_mat.type() != CV_32S) {
|
||||
mask_mat.convertTo(mask_mat, CV_32S);
|
||||
cv::Mat resized_mask;
|
||||
cv::Mat resized_score;
|
||||
|
||||
Tensor tensor_mask{};
|
||||
Tensor tensor_score{};
|
||||
|
||||
if (with_argmax_) {
|
||||
// mask
|
||||
if (mask_mat.channels() > 1) {
|
||||
cv::extractChannel(mask_mat, mask_mat, little_endian_ ? 0 : mask_mat.channels() - 1);
|
||||
}
|
||||
if (mask_mat.type() != CV_32S) {
|
||||
mask_mat.convertTo(mask_mat, CV_32S);
|
||||
}
|
||||
resized_mask = cpu::Resize(mask_mat, input_height, input_width, "nearest");
|
||||
tensor_mask = cpu::CVMat2Tensor(resized_mask);
|
||||
} else {
|
||||
// score
|
||||
resized_score = cpu::Resize(mask_mat, input_height, input_width, "bilinear");
|
||||
tensor_score = cpu::CVMat2Tensor(resized_score);
|
||||
std::vector<int> axes = {0, 3, 1, 2};
|
||||
::mmdeploy::operation::Context ctx(host, stream_);
|
||||
OUTCOME_TRY(permute_.Apply(tensor_score, tensor_score, axes));
|
||||
}
|
||||
|
||||
cv::Mat resized_mask = cpu::Resize(mask_mat, input_height, input_width, "nearest");
|
||||
|
||||
SegmentorOutput output{cpu::CVMat2Tensor(resized_mask), input_height, input_width, classes_};
|
||||
SegmentorOutput output{tensor_mask, tensor_score, input_height, input_width, classes_};
|
||||
return to_value(output);
|
||||
}
|
||||
|
||||
private:
|
||||
static Result<int> GetCvType(DataType type) {
|
||||
static Result<int> GetCvType(DataType type, int channel) {
|
||||
switch (type) {
|
||||
case DataType::kFLOAT:
|
||||
return CV_32F;
|
||||
return CV_32FC(channel);
|
||||
case DataType::kINT64:
|
||||
return CV_32SC2;
|
||||
case DataType::kINT32:
|
||||
|
@ -84,7 +120,9 @@ class ResizeMask : public MMSegmentation {
|
|||
}
|
||||
|
||||
protected:
|
||||
::mmdeploy::operation::Managed<::mmdeploy::operation::Permute> permute_;
|
||||
int classes_{};
|
||||
bool with_argmax_{true};
|
||||
bool little_endian_;
|
||||
};
|
||||
|
||||
|
|
|
@ -11,7 +11,8 @@ set(SRCS resize.cpp
|
|||
crop.cpp
|
||||
flip.cpp
|
||||
warp_affine.cpp
|
||||
crop_resize_pad.cpp)
|
||||
crop_resize_pad.cpp
|
||||
permute.cpp)
|
||||
|
||||
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
|
||||
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/operation/vision.h"
|
||||
#include "mmdeploy/utils/opencv/opencv_utils.h"
|
||||
|
||||
namespace mmdeploy::operation::cpu {
|
||||
|
||||
class PermuteImpl : public Permute {
|
||||
public:
|
||||
explicit PermuteImpl() {}
|
||||
|
||||
Result<void> apply(const Tensor& src, Tensor& dst, const std::vector<int>& axes) override {
|
||||
int ndim = src.shape().size();
|
||||
if (ndim != axes.size()) {
|
||||
MMDEPLOY_ERROR("The size of axes should be equal to src, {} vs {}", axes.size(), ndim);
|
||||
return Status(eInvalidArgument);
|
||||
}
|
||||
std::vector<int> axes_vis(ndim, 0);
|
||||
for (const auto& x : axes) {
|
||||
if (x < 0 || x >= ndim || axes_vis[x]) {
|
||||
MMDEPLOY_ERROR("Invalid axes");
|
||||
return Status(eInvalidArgument);
|
||||
}
|
||||
axes_vis[x] = 1;
|
||||
}
|
||||
|
||||
Tensor dst_tensor(src.desc());
|
||||
auto src_dims = src.shape();
|
||||
TensorShape dst_dims(ndim);
|
||||
for (int i = 0; i < src_dims.size(); i++) {
|
||||
dst_dims[i] = src_dims[axes[i]];
|
||||
}
|
||||
dst_tensor.Reshape(dst_dims);
|
||||
|
||||
std::vector<int> dst_strides(ndim);
|
||||
std::vector<int> src_strides(ndim);
|
||||
dst_strides[ndim - 1] = src_strides[ndim - 1] = 1;
|
||||
for (int i = ndim - 2; i >= 0; i--) {
|
||||
dst_strides[i] = dst_strides[i + 1] * dst_dims[i + 1];
|
||||
src_strides[i] = src_strides[i + 1] * src_dims[i + 1];
|
||||
}
|
||||
|
||||
std::vector<int> tmp(ndim);
|
||||
for (int i = 0; i < ndim; i++) {
|
||||
tmp[i] = src_strides[axes[i]];
|
||||
}
|
||||
src_strides.swap(tmp);
|
||||
|
||||
if (src.data_type() == DataType::kINT8) {
|
||||
OUTCOME_TRY(PermuteDispatch<uint8_t>(src, dst_tensor, src_strides, dst_strides));
|
||||
} else if (src.data_type() == DataType::kFLOAT) {
|
||||
OUTCOME_TRY(PermuteDispatch<float>(src, dst_tensor, src_strides, dst_strides));
|
||||
} else {
|
||||
MMDEPLOY_ERROR("unsupported data type {}", src.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
dst = std::move(dst_tensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Result<void> PermuteDispatch(const Tensor& src, Tensor& dst, const std::vector<int>& src_strides,
|
||||
const std::vector<int>& dst_strides) {
|
||||
auto shape = dst.shape();
|
||||
int ndim = src.shape().size();
|
||||
std::vector<int> coord(ndim, 0);
|
||||
auto dst_data = dst.data<T>();
|
||||
auto src_data = src.data<T>();
|
||||
|
||||
int i;
|
||||
do {
|
||||
dst_data[0] = src_data[0];
|
||||
for (i = ndim - 1; i >= 0; i--) {
|
||||
if (++coord[i] == shape[i]) {
|
||||
coord[i] = 0;
|
||||
dst_data -= (shape[i] - 1) * dst_strides[i];
|
||||
src_data -= (shape[i] - 1) * src_strides[i];
|
||||
} else {
|
||||
dst_data += dst_strides[i];
|
||||
src_data += src_strides[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
} while (i >= 0);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Registers a factory for the Permute operation under the (cpu, 0) key that
// creates the PermuteImpl defined in this namespace.
MMDEPLOY_REGISTER_FACTORY_FUNC(Permute, (cpu, 0), []() { return std::make_unique<PermuteImpl>(); });
|
||||
|
||||
} // namespace mmdeploy::operation::cpu
|
|
@ -19,7 +19,9 @@ set(SRCS resize.cpp
|
|||
crop.cu
|
||||
flip.cpp
|
||||
warp_affine.cpp
|
||||
crop_resize_pad.cpp)
|
||||
crop_resize_pad.cpp
|
||||
permute.cpp
|
||||
permute.cu)
|
||||
|
||||
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
|
||||
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/operation/cuda/permute.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "mmdeploy/operation/vision.h"
|
||||
|
||||
namespace mmdeploy::operation::cuda {
|
||||
|
||||
namespace impl {
|
||||
template <typename T>
|
||||
void Permute(const T* src, const TensorStride& src_strides, T* dst, const TensorStride& dst_strides,
|
||||
int ndim, int total, cudaStream_t stream);
|
||||
}
|
||||
|
||||
class PermuteImpl : public Permute {
|
||||
public:
|
||||
explicit PermuteImpl() {}
|
||||
|
||||
Result<void> apply(const Tensor& src, Tensor& dst, const std::vector<int>& axes) override {
|
||||
int ndim = src.shape().size();
|
||||
if (ndim != axes.size()) {
|
||||
MMDEPLOY_ERROR("The size of axes should be equal of src, {} vs {}", axes.size(), ndim);
|
||||
return Status(eInvalidArgument);
|
||||
}
|
||||
if (ndim > MAX_PERMUTE_DIM) {
|
||||
MMDEPLOY_ERROR("Only support ndim < {}", MAX_PERMUTE_DIM);
|
||||
return Status(eInvalidArgument);
|
||||
}
|
||||
std::vector<int> axes_vis(ndim, 0);
|
||||
for (const auto& x : axes) {
|
||||
if (x < 0 || x >= ndim || axes_vis[x]) {
|
||||
MMDEPLOY_ERROR("Invalid axes");
|
||||
return Status(eInvalidArgument);
|
||||
}
|
||||
axes_vis[x] = 1;
|
||||
}
|
||||
|
||||
Tensor dst_tensor(src.desc());
|
||||
auto src_dims = src.shape();
|
||||
TensorShape dst_dims(ndim);
|
||||
for (int i = 0; i < src_dims.size(); i++) {
|
||||
dst_dims[i] = src_dims[axes[i]];
|
||||
}
|
||||
dst_tensor.Reshape(dst_dims);
|
||||
|
||||
TensorStride dst_strides;
|
||||
TensorStride src_strides;
|
||||
|
||||
dst_strides[ndim - 1] = src_strides[ndim - 1] = 1;
|
||||
for (int i = ndim - 2; i >= 0; i--) {
|
||||
dst_strides[i] = dst_strides[i + 1] * dst_dims[i + 1];
|
||||
src_strides[i] = src_strides[i + 1] * src_dims[i + 1];
|
||||
}
|
||||
|
||||
TensorStride tmp;
|
||||
for (int i = 0; i < ndim; i++) {
|
||||
tmp[i] = src_strides[axes[i]];
|
||||
}
|
||||
src_strides = tmp;
|
||||
|
||||
if (src.data_type() == DataType::kINT8) {
|
||||
OUTCOME_TRY(PermuteDispatch<uint8_t>(src, dst_tensor, src_strides, dst_strides));
|
||||
} else if (src.data_type() == DataType::kFLOAT) {
|
||||
OUTCOME_TRY(PermuteDispatch<float>(src, dst_tensor, src_strides, dst_strides));
|
||||
} else {
|
||||
MMDEPLOY_ERROR("unsupported data type {}", src.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
dst = std::move(dst_tensor);
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Result<void> PermuteDispatch(const Tensor& src, Tensor& dst, const TensorStride& src_strides,
|
||||
const TensorStride& dst_strides) {
|
||||
auto src_data = src.data<T>();
|
||||
auto dst_data = dst.data<T>();
|
||||
auto ndim = src.shape().size();
|
||||
auto total = src.size();
|
||||
impl::Permute(src_data, src_strides, dst_data, dst_strides, ndim, total,
|
||||
GetNative<cudaStream_t>(stream()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Registers a creator for the CUDA PermuteImpl with the Permute factory.
// NOTE(review): the `(cuda, 0)` tuple presumably names the platform and a
// priority/version - confirm against the macro definition.
MMDEPLOY_REGISTER_FACTORY_FUNC(Permute, (cuda, 0),
|
||||
[]() { return std::make_unique<PermuteImpl>(); });
|
||||
|
||||
} // namespace mmdeploy::operation::cuda
|
|
@ -0,0 +1,49 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "mmdeploy/operation/cuda/permute.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace operation {
|
||||
namespace cuda {
|
||||
namespace impl {
|
||||
|
||||
template <typename T>
|
||||
__global__ void permute(const T* src, const TensorStride src_strides, T* dst,
|
||||
const TensorStride dst_strides, int ndim, int total) {
|
||||
int u = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (u >= total) {
|
||||
return;
|
||||
}
|
||||
|
||||
int remaining = u;
|
||||
int v = 0;
|
||||
for (size_t i = 0; i < ndim; i++) {
|
||||
int p = remaining / dst_strides.v_[i];
|
||||
remaining -= p * dst_strides.v_[i];
|
||||
v += p * src_strides.v_[i];
|
||||
}
|
||||
dst[u] = src[v];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Permute(const T* src, const TensorStride& src_strides, T* dst, const TensorStride& dst_strides,
|
||||
int ndim, int total, cudaStream_t stream) {
|
||||
int thread_num = 256;
|
||||
int block_num = (total + thread_num - 1) / thread_num;
|
||||
permute<T><<<block_num, thread_num, 0, stream>>>(src, src_strides, dst, dst_strides, ndim, total);
|
||||
}
|
||||
|
||||
template void Permute<float>(const float* src, const TensorStride& src_strides, float* dst,
|
||||
const TensorStride& dst_strides, int ndim, int total,
|
||||
cudaStream_t stream);
|
||||
|
||||
template void Permute<uint8_t>(const uint8_t* src, const TensorStride& src_strides, uint8_t* dst,
|
||||
const TensorStride& dst_strides, int ndim, int total,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace cuda
|
||||
} // namespace operation
|
||||
} // namespace mmdeploy
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#ifndef MMDEPLOY_OPERATION_CUDA_PERMUTE_H_
|
||||
#define MMDEPLOY_OPERATION_CUDA_PERMUTE_H_
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace operation {
|
||||
namespace cuda {
|
||||
|
||||
// Largest tensor rank the CUDA permute implementation accepts; also the
// capacity of TensorStride.
const int MAX_PERMUTE_DIM = 8;

// Fixed-capacity array of per-axis element strides. It is a plain aggregate
// so that it can be handed to a kernel by value.
struct TensorStride {
  int v_[MAX_PERMUTE_DIM];

  // Unchecked element access; `idx` must be below MAX_PERMUTE_DIM.
  int& operator[](size_t idx) { return v_[idx]; }

  // Read-only counterpart so a `const TensorStride&` can be indexed as well.
  const int& operator[](size_t idx) const { return v_[idx]; }
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace operation
|
||||
} // namespace mmdeploy
|
||||
|
||||
#endif // MMDEPLOY_OPERATION_CUDA_PERMUTE_H_
|
|
@ -14,5 +14,6 @@ MMDEPLOY_DEFINE_REGISTRY(Crop);
|
|||
MMDEPLOY_DEFINE_REGISTRY(Flip);
|
||||
MMDEPLOY_DEFINE_REGISTRY(WarpAffine);
|
||||
MMDEPLOY_DEFINE_REGISTRY(CropResizePad);
|
||||
MMDEPLOY_DEFINE_REGISTRY(Permute);
|
||||
|
||||
} // namespace mmdeploy::operation
|
||||
|
|
|
@ -92,6 +92,11 @@ class CropResizePad : public Operation {
|
|||
};
|
||||
|
||||
MMDEPLOY_DECLARE_REGISTRY(CropResizePad, unique_ptr<CropResizePad>());
|
||||
// Interface of the axis-permutation operation; device-specific back-ends
// derive from it and are created through the Permute registry.
class Permute : public Operation {
|
||||
public:
|
||||
// Writes to `dst` a tensor whose axis i is axis `axes[i]` of `src`.
// Returns an error Result when the request cannot be carried out.
virtual Result<void> apply(const Tensor& src, Tensor& dst, const std::vector<int>& axes) = 0;
|
||||
};
|
||||
// Declares the registry for Permute creators.
// NOTE(review): the second argument looks like the creator signature
// (no arguments, returns unique_ptr<Permute>) - confirm against the macro.
MMDEPLOY_DECLARE_REGISTRY(Permute, unique_ptr<Permute>());
|
||||
|
||||
} // namespace mmdeploy::operation
|
||||
|
||||
|
|
|
@ -126,7 +126,7 @@ cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width,
|
|||
const std::string& interpolation) {
|
||||
cv::Mat dst(dst_height, dst_width, src.type());
|
||||
auto method = GetInterpolationMethod(interpolation).value();
|
||||
cv::resize(src, dst, dst.size(), method);
|
||||
cv::resize(src, dst, dst.size(), 0, 0, method);
|
||||
return dst;
|
||||
}
|
||||
|
||||
|
|
|
@ -75,21 +75,50 @@ namespace image_segmentation
|
|||
unsafe
|
||||
{
|
||||
byte* data = colorMask.DataPointer;
|
||||
fixed (int* _label = output[0].Mask)
|
||||
if (output[0].Mask.Length > 0)
|
||||
{
|
||||
int* label = _label;
|
||||
for (int i = 0; i < output[0].Height; i++)
|
||||
fixed (int* _label = output[0].Mask)
|
||||
{
|
||||
for (int j = 0; j < output[0].Width; j++)
|
||||
int* label = _label;
|
||||
for (int i = 0; i < output[0].Height; i++)
|
||||
{
|
||||
data[0] = palette[*label][0];
|
||||
data[1] = palette[*label][1];
|
||||
data[2] = palette[*label][2];
|
||||
data += 3;
|
||||
label++;
|
||||
for (int j = 0; j < output[0].Width; j++)
|
||||
{
|
||||
data[0] = palette[*label][0];
|
||||
data[1] = palette[*label][1];
|
||||
data[2] = palette[*label][2];
|
||||
data += 3;
|
||||
label++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
int pos = 0;
|
||||
fixed (float* _score = output[0].Score)
|
||||
{
|
||||
float *score = _score;
|
||||
int total = output[0].Height * output[0].Width;
|
||||
for (int i = 0; i < output[0].Height; i++)
|
||||
{
|
||||
for (int j = 0; j < output[0].Width; j++)
|
||||
{
|
||||
List<Tuple<float, int>> scores = new List<Tuple<float, int>>();
|
||||
for (int k = 0; k < output[0].Classes; k++)
|
||||
{
|
||||
scores.Add(new Tuple<float, int>(score[k * total + i * output[0].Width + j], k));
|
||||
}
|
||||
scores.Sort();
|
||||
data[0] = palette[scores[^1].Item2][0];
|
||||
data[1] = palette[scores[^1].Item2][1];
|
||||
data[2] = palette[scores[^1].Item2][2];
|
||||
data += 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
colorMask = imgs[0] * 0.5 + colorMask * 0.5;
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
#include <opencv2/imgcodecs/imgcodecs.hpp>
|
||||
#include <opencv2/imgproc/imgproc.hpp>
|
||||
#include <random>
|
||||
|
@ -59,8 +60,25 @@ int main(int argc, char* argv[]) {
|
|||
|
||||
cv::Mat color_mask = cv::Mat::zeros(result->height, result->width, CV_8UC3);
|
||||
int pos = 0;
|
||||
int total = color_mask.rows * color_mask.cols;
|
||||
std::vector<int> idxs(result->classes);
|
||||
for (auto iter = color_mask.begin<cv::Vec3b>(); iter != color_mask.end<cv::Vec3b>(); ++iter) {
|
||||
*iter = palette[result->mask[pos++]];
|
||||
// output mask
|
||||
if (result->mask) {
|
||||
*iter = palette[result->mask[pos++]];
|
||||
}
|
||||
// output score
|
||||
if (result->score) {
|
||||
std::iota(idxs.begin(), idxs.end(), 0);
|
||||
auto k =
|
||||
std::max_element(idxs.begin(), idxs.end(),
|
||||
[&](int i, int j) {
|
||||
return result->score[i * total + pos] < result->score[j * total + pos];
|
||||
}) -
|
||||
idxs.begin();
|
||||
*iter = palette[k];
|
||||
pos += 1;
|
||||
}
|
||||
}
|
||||
|
||||
img = img * 0.5 + color_mask * 0.5;
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include "mmdeploy/segmentor.hpp"
|
||||
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
#include <opencv2/imgcodecs/imgcodecs.hpp>
|
||||
#include <opencv2/imgproc/imgproc.hpp>
|
||||
#include <random>
|
||||
|
@ -47,8 +48,25 @@ int main(int argc, char* argv[]) {
|
|||
|
||||
cv::Mat color_mask = cv::Mat::zeros(result->height, result->width, CV_8UC3);
|
||||
int pos = 0;
|
||||
int total = color_mask.rows * color_mask.cols;
|
||||
std::vector<int> idxs(result->classes);
|
||||
for (auto iter = color_mask.begin<cv::Vec3b>(); iter != color_mask.end<cv::Vec3b>(); ++iter) {
|
||||
*iter = palette[result->mask[pos++]];
|
||||
// output mask
|
||||
if (result->mask) {
|
||||
*iter = palette[result->mask[pos++]];
|
||||
}
|
||||
// output score
|
||||
if (result->score) {
|
||||
std::iota(idxs.begin(), idxs.end(), 0);
|
||||
auto k =
|
||||
std::max_element(idxs.begin(), idxs.end(),
|
||||
[&](int i, int j) {
|
||||
return result->score[pos + i * total] < result->score[pos + j * total];
|
||||
}) -
|
||||
idxs.begin();
|
||||
*iter = palette[k];
|
||||
pos += 1;
|
||||
}
|
||||
}
|
||||
|
||||
img = img * 0.5 + color_mask * 0.5;
|
||||
|
|
|
@ -35,6 +35,8 @@ def main():
|
|||
segmentor = Segmentor(
|
||||
model_path=args.model_path, device_name=args.device_name, device_id=0)
|
||||
seg = segmentor(img)
|
||||
if seg.dtype == np.float32:
|
||||
seg = np.argmax(seg, axis=0)
|
||||
|
||||
palette = get_palette()
|
||||
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "catch.hpp"
|
||||
#include "mmdeploy/core/mat.h"
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
#include "mmdeploy/core/utils/device_utils.h"
|
||||
#include "mmdeploy/operation/managed.h"
|
||||
#include "mmdeploy/operation/vision.h"
|
||||
#include "mmdeploy/preprocess/transform/transform.h"
|
||||
#include "test_resource.h"
|
||||
#include "test_utils.h"
|
||||
|
||||
using namespace mmdeploy;
|
||||
using namespace framework;
|
||||
using namespace std;
|
||||
using namespace mmdeploy::test;
|
||||
|
||||
template <typename T>
|
||||
bool CheckEqual(const Tensor& res, const vector<T>& expected) {
|
||||
auto r = res.data<T>();
|
||||
auto e = expected.data();
|
||||
for (int i = 0; i < expected.size(); i++) {
|
||||
if (r[i] != e[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TestPermute(const Tensor& src, const vector<int>& axes, const vector<T>& expected) {
|
||||
auto gResource = MMDeployTestResources::Get();
|
||||
for (auto const& device_name : gResource.device_names()) {
|
||||
Device device{device_name.c_str()};
|
||||
Stream stream{device};
|
||||
::mmdeploy::operation::Context ctx(device, stream);
|
||||
auto permute = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create();
|
||||
Tensor dst;
|
||||
auto ret = permute.Apply(src, dst, axes);
|
||||
REQUIRE(!ret.has_error());
|
||||
const Device kHost{"cpu"};
|
||||
auto host_tensor = MakeAvailableOnDevice(dst, kHost, stream);
|
||||
REQUIRE(CheckEqual(host_tensor.value(), expected));
|
||||
}
|
||||
}
|
||||
|
||||
void TestPermuteWrongArgs(const Tensor& src) {
|
||||
int sz = src.shape().size();
|
||||
vector<int> oaxes(sz);
|
||||
std::iota(oaxes.begin(), oaxes.end(), 0);
|
||||
|
||||
auto gResource = MMDeployTestResources::Get();
|
||||
for (auto const& device_name : gResource.device_names()) {
|
||||
Device device{device_name.c_str()};
|
||||
Stream stream{device};
|
||||
::mmdeploy::operation::Context ctx(device, stream);
|
||||
auto permute = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create();
|
||||
Tensor dst;
|
||||
{
|
||||
auto axes = oaxes;
|
||||
axes[0]--;
|
||||
auto ret = permute.Apply(src, dst, axes);
|
||||
REQUIRE(ret.has_error());
|
||||
}
|
||||
{
|
||||
auto axes = oaxes;
|
||||
axes.back()++;
|
||||
auto ret = permute.Apply(src, dst, axes);
|
||||
REQUIRE(ret.has_error());
|
||||
}
|
||||
{
|
||||
auto axes = oaxes;
|
||||
axes[0] = axes[1];
|
||||
auto ret = permute.Apply(src, dst, axes);
|
||||
REQUIRE(ret.has_error());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exercises Permute on a 48-element byte tensor holding 0..47, viewed with
// different shapes: argument validation on a 2-D view, then 4-D and 5-D
// permutations checked against precomputed expected buffers.
TEST_CASE("operation Permute", "[permute]") {
  const Device kHost{"cpu"};
  const int kSize = 2 * 3 * 2 * 4;
  vector<uint8_t> data(kSize);
  std::iota(data.begin(), data.end(), 0);  // 0, 1, ..., 47
  TensorDesc desc = {kHost, DataType::kINT8, {kSize}};
  Tensor tensor(desc);
  memcpy(tensor.data(), data.data(), data.size() * sizeof(uint8_t));

  SECTION("permute: wrong axes") {
    Tensor src = tensor;
    src.Reshape({6, 8});
    TestPermuteWrongArgs(src);
  }

  SECTION("permute: dims 4") {
    Tensor src = tensor;
    src.Reshape({2, 3, 2, 4});
    vector<int> axes = {1, 0, 3, 2};
    vector<uint8_t> expected = {0,  4,  1,  5,  2,  6,  3,  7,  24, 28, 25, 29, 26, 30, 27, 31,
                                8,  12, 9,  13, 10, 14, 11, 15, 32, 36, 33, 37, 34, 38, 35, 39,
                                16, 20, 17, 21, 18, 22, 19, 23, 40, 44, 41, 45, 42, 46, 43, 47};
    TestPermute(src, axes, expected);
  }

  SECTION("permute: dims 5") {
    Tensor src = tensor;
    src.Reshape({2, 3, 1, 2, 4});
    vector<int> axes = {2, 0, 1, 4, 3};
    vector<uint8_t> expected = {0,  4,  1,  5,  2,  6,  3,  7,  8,  12, 9,  13, 10, 14, 11, 15,
                                16, 20, 17, 21, 18, 22, 19, 23, 24, 28, 25, 29, 26, 30, 27, 31,
                                32, 36, 33, 37, 34, 38, 35, 39, 40, 44, 41, 45, 42, 46, 43, 47};
    TestPermute(src, axes, expected);
  }
}
|
Loading…
Reference in New Issue