[Feature] Support feature map output for mmsegmentation ()

* add feature map output for mmseg

* update api

* update demo

* fix return

* update format_shape

* fix lint

* update csharp demo

* update python demo && api

* fix coreml build

* fix lint

* better sort

* update

* update cpp demo & add missing header

* change to CHW

* update csharp api

* update isort version to 5.12.0

* fix python api

* fix log

* more detail api docs

* isort support python3.7

* remove isort change

* remove whitespace

* axes check

* remove FormatShapeImpl

* minor

* add permute tc

* remove stride buffer

(cherry picked from commit b85f34141b)
pull/1726/head
Chen Xin 2023-02-03 20:47:55 +08:00 committed by zhangli
parent a9f8d8c951
commit 1f56eea807
30 changed files with 668 additions and 332 deletions

View File

@ -119,10 +119,17 @@ int mmdeploy_segmentor_get_result(mmdeploy_value_t output, mmdeploy_segmentation
results_ptr->height = segmentor_output.height;
results_ptr->width = segmentor_output.width;
results_ptr->classes = segmentor_output.classes;
auto mask_size = results_ptr->height * results_ptr->width;
auto& mask = segmentor_output.mask;
results_ptr->mask = mask.data<int>();
buffers[i] = mask.buffer();
auto& score = segmentor_output.score;
results_ptr->mask = nullptr;
results_ptr->score = nullptr;
if (mask.shape().size()) {
results_ptr->mask = mask.data<int>();
buffers[i] = mask.buffer();
} else {
results_ptr->score = score.data<float>();
buffers[i] = score.buffer();
}
}
*results = results_data;

View File

@ -17,11 +17,14 @@ extern "C" {
#endif
typedef struct mmdeploy_segmentation_t {
int height; ///< height of \p mask that equals to the input image's height
int width; ///< width of \p mask that equals to the input image's width
int classes; ///< the number of labels in \p mask
int* mask; ///< segmentation mask of the input image, in which mask[i * width + j] indicates
///< the label id of pixel at (i, j)
int height; ///< height of \p mask that equals to the input image's height
int width; ///< width of \p mask that equals to the input image's width
int classes; ///< the number of labels in \p mask
int* mask; ///< segmentation mask of the input image, in which mask[i * width + j] indicates
///< the label id of pixel at (i, j); this field might be null
float* score; ///< segmentation score map of the input image in CHW format, in which
///< score[height * width * k + i * width + j] indicates the score
///< of class k at pixel (i, j); this field might be null
} mmdeploy_segmentation_t;
typedef struct mmdeploy_segmentor* mmdeploy_segmentor_t;
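
With this change exactly one of mask and score is populated per result: mask when the pipeline keeps the default argmax output, score (a classes x height x width float map) when with_argmax is disabled. Below is a minimal consumer sketch; it assumes the C header is available as "mmdeploy/segmentor.h", and the helper name is purely illustrative.

// Sketch only: assumes the header path "mmdeploy/segmentor.h" and that exactly
// one of mask / score is non-null, as documented above. ToLabelMap is an
// illustrative helper, not part of the API.
#include <vector>

#include "mmdeploy/segmentor.h"

std::vector<int> ToLabelMap(const mmdeploy_segmentation_t& seg) {
  const int h = seg.height, w = seg.width, c = seg.classes;
  std::vector<int> labels(h * w);
  if (seg.mask) {  // with_argmax = true: labels are precomputed
    labels.assign(seg.mask, seg.mask + h * w);
  } else {  // with_argmax = false: argmax over the CHW score map
    for (int i = 0; i < h; ++i) {
      for (int j = 0; j < w; ++j) {
        int best = 0;
        for (int k = 1; k < c; ++k) {
          if (seg.score[k * h * w + i * w + j] > seg.score[best * h * w + i * w + j]) {
            best = k;
          }
        }
        labels[i * w + j] = best;
      }
    }
  }
  return labels;
}

The C, C++ and C# demos later in this diff implement the same mask-or-score branch when building the color mask.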

View File

@ -10,6 +10,7 @@ namespace MMDeploy
public int Width;
public int Classes;
public int* Mask;
public float* Score;
}
#pragma warning restore 0649
@ -34,10 +35,16 @@ namespace MMDeploy
public int Classes;
/// <summary>
/// Mask data.
/// Mask data, mask[i * width + j] indicates the label id of pixel at (i, j).
/// </summary>
public int[] Mask;
/// <summary>
/// Score data, score[height * width * k + i * width + j] indicates the score
/// of class k at pixel (i, j).
/// </summary>
public float[] Score;
/// <summary>
/// Initializes a new instance of the <see cref="SegmentorOutput"/> struct.
/// </summary>
@ -45,13 +52,31 @@ namespace MMDeploy
/// <param name="width">width.</param>
/// <param name="classes">classes.</param>
/// <param name="mask">mask.</param>
public SegmentorOutput(int height, int width, int classes, int[] mask)
/// <param name="score">score.</param>
public SegmentorOutput(int height, int width, int classes, int[] mask, float[] score)
{
Height = height;
Width = width;
Classes = classes;
Mask = new int[Height * Width];
Array.Copy(mask, this.Mask, mask.Length);
if (mask.Length > 0)
{
Mask = new int[Height * Width];
Array.Copy(mask, this.Mask, mask.Length);
}
else
{
Mask = new int[] { };
}
if (score.Length > 0)
{
Score = new float[Height * Width * Classes];
Array.Copy(score, this.Score, score.Length);
}
else
{
Score = new float[] { };
}
}
internal unsafe SegmentorOutput(CSegment* result)
@ -59,11 +84,34 @@ namespace MMDeploy
Height = result->Height;
Width = result->Width;
Classes = result->Classes;
Mask = new int[Height * Width];
int nbytes = Height * Width * sizeof(int);
fixed (int* data = this.Mask)
if (result->Mask != null)
{
Buffer.MemoryCopy(result->Mask, data, nbytes, nbytes);
Mask = new int[Height * Width];
int nbytes = Height * Width * sizeof(int);
fixed (int* data = this.Mask)
{
Buffer.MemoryCopy(result->Mask, data, nbytes, nbytes);
}
}
else
{
Mask = new int[] { };
}
if (result->Score != null)
{
Score = new float[Height * Width * Classes];
int nbytes = Height * Width * Classes * sizeof(float);
fixed (float* data = this.Score)
{
Buffer.MemoryCopy(result->Score, data, nbytes, nbytes);
}
}
else
{
Score = new float[] { };
}
}
}

View File

@ -37,12 +37,22 @@ class PySegmentor {
std::vector<py::array> rets(mats.size());
for (size_t i = 0; i < mats.size(); ++i) {
rets[i] = {
{segm[i].height, segm[i].width}, // shape
segm[i].mask, // data
py::capsule(new Sptr(holder), // handle
[](void* p) { delete reinterpret_cast<Sptr*>(p); }) //
};
if (segm[i].mask != nullptr) {
rets[i] = {
{segm[i].height, segm[i].width}, // shape
segm[i].mask, // mask
py::capsule(new Sptr(holder), // handle
[](void* p) { delete reinterpret_cast<Sptr*>(p); }) //
};
}
if (segm[i].score != nullptr) {
rets[i] = {
{segm[i].classes, segm[i].height, segm[i].width}, // shape
segm[i].score, // score
py::capsule(new Sptr(holder), // handle
[](void* p) { delete reinterpret_cast<Sptr*>(p); }) //
};
}
}
return rets;
}

View File

@ -4,8 +4,7 @@ project(mmdeploy_mmaction)
file(GLOB SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
add_subdirectory(cpu)
add_subdirectory(cuda)
target_link_libraries(${PROJECT_NAME} PRIVATE
mmdeploy_operation
mmdeploy_transform

View File

@ -1,15 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
project(mmdeploy_mmaction_cpu_impl CXX)
if ("cpu" IN_LIST MMDEPLOY_TARGET_DEVICES)
add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp)
set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1)
if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC))
target_compile_options(${PROJECT_NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
endif ()
target_link_libraries(${PROJECT_NAME} PRIVATE
mmdeploy::core)
target_link_libraries(mmdeploy_mmaction PRIVATE ${PROJECT_NAME})
mmdeploy_export(${PROJECT_NAME})
endif ()

View File

@ -1,67 +0,0 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "mmdeploy/codebase/mmaction/format_shape.h"
#include "mmdeploy/core/utils/device_utils.h"
using namespace std;
namespace mmdeploy::mmaction::cpu {
class FormatShapeImpl : public FormatShapeOp {
public:
explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {}
protected:
Device host_{0, 0};
const Device& GetDevice() { return host_; }
Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
const std::vector<int>& permutation) {
Tensor dst(src.desc());
TensorShape shape(src.shape().size());
for (int i = 0; i < shape.size(); i++) {
shape[i] = src.shape(permutation[i]);
}
dst.Reshape(shape);
int ndim = shape.size();
std::vector<int> dst_strides(ndim);
std::vector<int> src_strides(ndim);
dst_strides[ndim - 1] = src_strides[ndim - 1] = 1;
for (int i = ndim - 2; i >= 0; i--) {
dst_strides[i] = dst_strides[i + 1] * shape[i + 1];
src_strides[i] = src_strides[i + 1] * src_dims[i + 1];
}
std::vector<int> tmp(ndim);
for (int i = 0; i < ndim; i++) {
tmp[i] = src_strides[permutation[i]];
}
src_strides.swap(tmp);
std::vector<int> coord(ndim, 0);
auto dst_data = dst.data<float>();
auto src_data = src.data<float>();
int i;
do {
dst_data[0] = src_data[0];
for (i = ndim - 1; i >= 0; i--) {
if (++coord[i] == shape[i]) {
coord[i] = 0;
dst_data -= (shape[i] - 1) * dst_strides[i];
src_data -= (shape[i] - 1) * src_strides[i];
} else {
dst_data += dst_strides[i];
src_data += src_strides[i];
break;
}
}
} while (i >= 0);
return dst;
}
};
MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cpu, 0), [](std::string input_format) {
return std::make_unique<FormatShapeImpl>(std::move(input_format));
});
} // namespace mmdeploy::mmaction::cpu

View File

@ -1,18 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
if (NOT "cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
return()
endif ()
project(mmdeploy_mmaction_cuda_impl CXX)
add_library(${PROJECT_NAME} OBJECT format_shape_impl.cpp transpose.cu)
set_target_properties(${PROJECT_NAME} PROPERTIES POSITION_INDEPENDENT_CODE 1)
if (NOT (MMDEPLOY_SHARED_LIBS OR MSVC))
target_compile_options(${PROJECT_NAME} PRIVATE $<$<COMPILE_LANGUAGE:CXX>:-fvisibility=hidden>)
endif ()
target_include_directories(${PROJECT_NAME} PRIVATE
${CUDA_INCLUDE_DIRS})
target_link_libraries(${PROJECT_NAME} PRIVATE
mmdeploy::core)
target_link_libraries(mmdeploy_mmaction PRIVATE ${PROJECT_NAME})
mmdeploy_export(${PROJECT_NAME})

View File

@ -1,66 +0,0 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "cuda_runtime.h"
#include "mmdeploy/codebase/mmaction/format_shape.h"
#include "mmdeploy/core/utils/device_utils.h"
using namespace std;
namespace mmdeploy::mmaction::cuda {
template <typename T>
void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim,
int total, cudaStream_t stream);
class FormatShapeImpl : public FormatShapeOp {
public:
explicit FormatShapeImpl(std::string input_format) : FormatShapeOp(std::move(input_format)) {}
protected:
const Device& GetDevice() { return device(); }
Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
const std::vector<int>& permutation) {
Tensor dst(src.desc());
TensorShape shape(src.shape().size());
for (int i = 0; i < shape.size(); i++) {
shape[i] = src.shape(permutation[i]);
}
dst.Reshape(shape);
auto ndim = src_dims.size();
std::vector<int> dst_dims(ndim);
for (int i = 0; i < ndim; i++) {
dst_dims[i] = src_dims[permutation[i]];
}
std::vector<int> src_strides(ndim);
std::vector<int> dst_strides(ndim);
std::vector<int> buffer(ndim);
buffer.back() = 1;
dst_strides.back() = 1;
for (int i = ndim - 1; i > 0; i--) {
buffer[i - 1] = buffer[i] * src_dims[i];
dst_strides[i - 1] = dst_strides[i] * dst_dims[i];
}
for (int i = 0; i < ndim; ++i) {
src_strides[i] = buffer[permutation[i]];
}
Buffer _src_strides(Device("cuda"), sizeof(int) * ndim);
Buffer _dst_strides(Device("cuda"), sizeof(int) * ndim);
OUTCOME_TRY(stream().Copy(src_strides.data(), _src_strides));
OUTCOME_TRY(stream().Copy(dst_strides.data(), _dst_strides));
::mmdeploy::mmaction::cuda::Transpose(src.data<float>(), GetNative<int*>(_src_strides),
dst.data<float>(), GetNative<int*>(_dst_strides), ndim,
src.size(), (cudaStream_t)stream().GetNative());
return dst;
}
};
MMDEPLOY_REGISTER_FACTORY_FUNC(FormatShapeOp, (cuda, 0), [](std::string input_format) {
return std::make_unique<FormatShapeImpl>(std::move(input_format));
});
} // namespace mmdeploy::mmaction::cuda

View File

@ -1,42 +0,0 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <stdint.h>
#include <stdio.h>
namespace mmdeploy {
namespace mmaction {
namespace cuda {
template <typename T>
__global__ void transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides,
int ndim, int total) {
int u = blockIdx.x * blockDim.x + threadIdx.x;
if (u >= total) {
return;
}
int remaining = u;
int v = 0;
for (int i = 0; i < ndim; i++) {
int p = remaining / dst_strides[i];
remaining -= p * dst_strides[i];
v += p * src_strides[i];
}
dst[u] = src[v];
}
template <typename T>
void Transpose(const T* src, const int* src_strides, T* dst, const int* dst_strides, int ndim,
int total, cudaStream_t stream) {
int thread_num = 256;
int block_num = (total + thread_num - 1) / thread_num;
transpose<T>
<<<block_num, thread_num, 0, stream>>>(src, src_strides, dst, dst_strides, ndim, total);
}
template void Transpose<float>(const float* src, const int* src_strides, float* dst,
const int* dst_strides, int ndim, int total, cudaStream_t stream);
} // namespace cuda
} // namespace mmaction
} // namespace mmdeploy

View File

@ -10,28 +10,47 @@ using namespace std;
namespace mmdeploy::mmaction {
FormatShape::FormatShape(const Value& args) {
auto input_format = args.value("input_format", std::string(""));
if (input_format != "NCHW" && input_format != "NCTHW") {
input_format_ = args.value("input_format", std::string(""));
if (input_format_ != "NCHW" && input_format_ != "NCTHW") {
MMDEPLOY_ERROR("'input_format' should be 'NCHW' or 'NCTHW'");
throw_exception(eInvalidArgument);
}
format_ = operation::Managed<mmdeploy::mmaction::FormatShapeOp>::Create(input_format);
permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create();
}
Result<void> FormatShapeOp::apply(const std::vector<Tensor>& images, Tensor& output, int clip_len,
int num_clips) {
Result<void> FormatShape::MergeInputs(const std::vector<Tensor>& images, Tensor& inputs) {
auto N = static_cast<int64_t>(images.size());
auto H = images[0].shape(1);
auto W = images[0].shape(2);
auto C = images[0].shape(3);
auto& device = operation::gContext().device();
auto& stream = operation::gContext().stream();
TensorDesc desc = {device, DataType::kFLOAT, {N, H, W, C}};
inputs = Tensor(desc);
auto offset = 0UL;
auto n_item = H * W * C;
auto copy_size = n_item * sizeof(float);
for (int i = 0; i < N; i++) {
auto src_buffer = images[i].buffer();
auto dst_buffer = inputs.buffer();
OUTCOME_TRY(stream.Copy(src_buffer, dst_buffer, copy_size, 0, offset));
offset += copy_size;
}
return success();
}
Result<void> FormatShape::Format(const std::vector<Tensor>& images, Tensor& output, int clip_len,
int num_clips) {
Tensor inputs;
OUTCOME_TRY(MergeInputs(images, inputs));
if (GetDevice().is_host()) {
OUTCOME_TRY(stream().Wait());
}
// Tensor dst;
if (input_format_ == "NCHW") {
OUTCOME_TRY(output, FormatNCHW(inputs, clip_len, num_clips));
OUTCOME_TRY(FormatNCHW(inputs, clip_len, num_clips, output));
}
if (input_format_ == "NCTHW") {
OUTCOME_TRY(output, FormatNCTHW(inputs, clip_len, num_clips));
OUTCOME_TRY(FormatNCTHW(inputs, clip_len, num_clips, output));
}
TensorShape expand_dim = output.shape();
@ -41,35 +60,13 @@ Result<void> FormatShapeOp::apply(const std::vector<Tensor>& images, Tensor& out
return success();
}
Result<void> FormatShapeOp::MergeInputs(const std::vector<Tensor>& images, Tensor& inputs) {
auto N = static_cast<int64_t>(images.size());
auto H = images[0].shape(1);
auto W = images[0].shape(2);
auto C = images[0].shape(3);
TensorDesc desc = {GetDevice(), DataType::kFLOAT, {N, H, W, C}};
inputs = Tensor(desc);
auto offset = 0UL;
auto n_item = H * W * C;
auto copy_size = n_item * sizeof(float);
for (int i = 0; i < N; i++) {
auto src_buffer = images[i].buffer();
auto dst_buffer = inputs.buffer();
OUTCOME_TRY(stream().Copy(src_buffer, dst_buffer, copy_size, 0, offset));
offset += copy_size;
}
Result<void> FormatShape::FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) {
const vector<int> axes = {0, 3, 1, 2};
OUTCOME_TRY(permute_.Apply(src, dst, axes));
return success();
}
Result<Tensor> FormatShapeOp::FormatNCHW(Tensor& src, int clip_len, int num_clips) {
auto N = src.shape(0);
auto H = src.shape(1);
auto W = src.shape(2);
auto C = src.shape(3);
return Transpose(src, {N, H, W, C}, {0, 3, 1, 2});
}
Result<Tensor> FormatShapeOp::FormatNCTHW(Tensor& src, int clip_len, int num_clips) {
Result<void> FormatShape::FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) {
auto N = src.shape(0);
auto H = src.shape(1);
auto W = src.shape(2);
@ -80,8 +77,9 @@ Result<Tensor> FormatShapeOp::FormatNCTHW(Tensor& src, int clip_len, int num_cli
}
int M = N / L;
src.Reshape({M, L, H, W, C});
return Transpose(src, {M, L, H, W, C}, {0, 4, 1, 2, 3});
const vector<int> axes = {0, 4, 1, 2, 3};
OUTCOME_TRY(permute_.Apply(src, dst, axes));
return success();
}
Result<void> FormatShape::Apply(Value& data) {
@ -119,7 +117,7 @@ Result<void> FormatShape::Apply(Value& data) {
Tensor dst;
data = Value{};
OUTCOME_TRY(format_.Apply(images, dst, clip_len, num_clips));
OUTCOME_TRY(Format(images, dst, clip_len, num_clips));
data["img"] = std::move(dst);
return success();
@ -127,6 +125,4 @@ Result<void> FormatShape::Apply(Value& data) {
MMDEPLOY_REGISTER_TRANSFORM(FormatShape);
MMDEPLOY_DEFINE_REGISTRY(FormatShapeOp);
} // namespace mmdeploy::mmaction
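
After MergeInputs stacks the N HWC frames into an {N, H, W, C} tensor, Format only does shape bookkeeping plus one permute. The sketch below restates that bookkeeping with plain std:: types; the function name is illustrative and not part of the SDK.

// Illustrative sketch of FormatShape::Format's shape handling, not SDK code.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

using Shape = std::vector<int64_t>;

Shape FormattedShape(int64_t n, int64_t h, int64_t w, int64_t c, int clip_len,
                     const std::string& input_format) {
  // MergeInputs produced {N, H, W, C}.
  if (input_format == "NCHW") {
    return {n, c, h, w};  // permute axes {0, 3, 1, 2}
  }
  // NCTHW: reshape to {M, L, H, W, C} with L = clip_len, then permute.
  assert(n % clip_len == 0);
  int64_t m = n / clip_len;
  return {m, c, clip_len, h, w};  // permute axes {0, 4, 1, 2, 3}
}

// e.g. 16 stacked 224x224x3 frames with clip_len = 8 -> {2, 3, 8, 224, 224}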

View File

@ -1,50 +1,38 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_SRC_CODEBASE_MMACTION_FORMAT_SHAPE_H_
#define MMDEPLOY_SRC_CODEBASE_MMACTION_FORMAT_SHAPE_H_
#ifndef MMDEPLOY_CODEBASE_MMACTION_FORMAT_SHAPE_H_
#define MMDEPLOY_CODEBASE_MMACTION_FORMAT_SHAPE_H_
#include <array>
#include <string>
#include <vector>
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/operation/managed.h"
#include "mmdeploy/operation/vision.h"
#include "mmdeploy/preprocess/transform/transform.h"
namespace mmdeploy::mmaction {
class FormatShapeOp : public operation::Operation {
public:
explicit FormatShapeOp(std::string input_format) : input_format_(std::move(input_format)){};
Result<void> apply(const std::vector<Tensor>& inputs, Tensor& output, int clip_len,
int num_clips);
virtual const Device& GetDevice() = 0;
virtual Result<Tensor> Transpose(Tensor& src, const TensorShape& src_dims,
const std::vector<int>& permutation) = 0;
Result<Tensor> FormatNCHW(Tensor& src, int clip_len, int num_clips);
Result<Tensor> FormatNCTHW(Tensor& src, int clip_len, int num_clips);
Result<void> MergeInputs(const std::vector<Tensor>& images, Tensor& inputs);
protected:
std::string input_format_;
};
class FormatShape : public Transform {
public:
explicit FormatShape(const Value& args);
Result<void> Apply(Value& data) override;
private:
operation::Managed<FormatShapeOp> format_;
};
Result<void> Format(const std::vector<Tensor>& images, Tensor& output, int clip_len,
int num_clips);
MMDEPLOY_DECLARE_REGISTRY(FormatShapeOp, std::unique_ptr<FormatShapeOp>(std::string input_format));
Result<void> FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst);
Result<void> FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst);
Result<void> MergeInputs(const std::vector<Tensor>& images, Tensor& inputs);
private:
std::string input_format_;
operation::Managed<operation::Permute> permute_;
};
} // namespace mmdeploy::mmaction

View File

@ -1,7 +1,7 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_SRC_CODEBASE_MMACTION_MMACTION_H_
#define MMDEPLOY_SRC_CODEBASE_MMACTION_MMACTION_H_
#ifndef MMDEPLOY_CODEBASE_MMACTION_MMACTION_H_
#define MMDEPLOY_CODEBASE_MMACTION_MMACTION_H_
#include "mmdeploy/codebase/common.h"
#include "mmdeploy/core/device.h"

View File

@ -4,7 +4,9 @@ project(mmdeploy_mmseg)
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy_opencv_utils)
target_link_libraries(${PROJECT_NAME} PRIVATE
mmdeploy_opencv_utils
mmdeploy_operation)
add_library(mmdeploy::mmseg ALIAS ${PROJECT_NAME})
set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} segmentor CACHE INTERNAL "")

View File

@ -12,10 +12,11 @@ namespace mmdeploy::mmseg {
struct SegmentorOutput {
Tensor mask;
Tensor score;
int height;
int width;
int classes;
MMDEPLOY_ARCHIVE_MEMBERS(mask, height, width, classes);
MMDEPLOY_ARCHIVE_MEMBERS(mask, score, height, width, classes);
};
MMDEPLOY_DECLARE_CODEBASE(MMSegmentation, mmseg);

View File

@ -5,6 +5,8 @@
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/operation/managed.h"
#include "mmdeploy/operation/vision.h"
#include "mmdeploy/preprocess/transform/transform.h"
#include "opencv_utils.h"
@ -18,7 +20,10 @@ class ResizeMask : public MMSegmentation {
explicit ResizeMask(const Value &cfg) : MMSegmentation(cfg) {
try {
classes_ = cfg["params"]["num_classes"].get<int>();
with_argmax_ = cfg["params"].value("with_argmax", true);
little_endian_ = IsLittleEndian();
::mmdeploy::operation::Context ctx(Device("cpu"), stream_);
permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create();
} catch (const std::exception &e) {
MMDEPLOY_ERROR("no ['params']['num_classes'] is specified in cfg: {}", cfg);
throw_exception(eInvalidArgument);
@ -31,40 +36,71 @@ class ResizeMask : public MMSegmentation {
auto mask = inference_result["output"].get<Tensor>();
MMDEPLOY_DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(),
mask.shape(), mask.data_type());
if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) {
if (!(mask.shape().size() == 4 && mask.shape(0) == 1)) {
MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}", mask.shape());
return Status(eNotSupported);
}
if ((mask.shape(1) != 1) && with_argmax_) {
MMDEPLOY_ERROR("probability feat map with shape: {} requires `with_argmax_=false`",
mask.shape());
return Status(eNotSupported);
}
if ((mask.data_type() != DataType::kFLOAT) && !with_argmax_) {
MMDEPLOY_ERROR("probability feat map only support float32 output");
return Status(eNotSupported);
}
auto channel = (int)mask.shape(1);
auto height = (int)mask.shape(2);
auto width = (int)mask.shape(3);
auto input_height = preprocess_result["img_metas"]["ori_shape"][1].get<int>();
auto input_width = preprocess_result["img_metas"]["ori_shape"][2].get<int>();
Device host{"cpu"};
OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(mask, host, stream_));
OUTCOME_TRY(stream_.Wait());
if (!with_argmax_) {
// (C, H, W) -> (H, W, C)
::mmdeploy::operation::Context ctx(host, stream_);
std::vector<int> axes = {0, 2, 3, 1};
OUTCOME_TRY(permute_.Apply(host_tensor, host_tensor, axes));
}
OUTCOME_TRY(auto cv_type, GetCvType(mask.data_type()));
OUTCOME_TRY(auto cv_type, GetCvType(mask.data_type(), channel));
cv::Mat mask_mat(height, width, cv_type, host_tensor.data());
if (mask_mat.channels() > 1) {
cv::extractChannel(mask_mat, mask_mat, little_endian_ ? 0 : mask_mat.channels() - 1);
}
if (mask_mat.type() != CV_32S) {
mask_mat.convertTo(mask_mat, CV_32S);
cv::Mat resized_mask;
cv::Mat resized_score;
Tensor tensor_mask{};
Tensor tensor_score{};
if (with_argmax_) {
// mask
if (mask_mat.channels() > 1) {
cv::extractChannel(mask_mat, mask_mat, little_endian_ ? 0 : mask_mat.channels() - 1);
}
if (mask_mat.type() != CV_32S) {
mask_mat.convertTo(mask_mat, CV_32S);
}
resized_mask = cpu::Resize(mask_mat, input_height, input_width, "nearest");
tensor_mask = cpu::CVMat2Tensor(resized_mask);
} else {
// score
resized_score = cpu::Resize(mask_mat, input_height, input_width, "bilinear");
tensor_score = cpu::CVMat2Tensor(resized_score);
std::vector<int> axes = {0, 3, 1, 2};
::mmdeploy::operation::Context ctx(host, stream_);
OUTCOME_TRY(permute_.Apply(tensor_score, tensor_score, axes));
}
cv::Mat resized_mask = cpu::Resize(mask_mat, input_height, input_width, "nearest");
SegmentorOutput output{cpu::CVMat2Tensor(resized_mask), input_height, input_width, classes_};
SegmentorOutput output{tensor_mask, tensor_score, input_height, input_width, classes_};
return to_value(output);
}
private:
static Result<int> GetCvType(DataType type) {
static Result<int> GetCvType(DataType type, int channel) {
switch (type) {
case DataType::kFLOAT:
return CV_32F;
return CV_32FC(channel);
case DataType::kINT64:
return CV_32SC2;
case DataType::kINT32:
@ -84,7 +120,9 @@ class ResizeMask : public MMSegmentation {
}
protected:
::mmdeploy::operation::Managed<::mmdeploy::operation::Permute> permute_;
int classes_{};
bool with_argmax_{true};
bool little_endian_;
};

View File

@ -11,7 +11,8 @@ set(SRCS resize.cpp
crop.cpp
flip.cpp
warp_affine.cpp
crop_resize_pad.cpp)
crop_resize_pad.cpp
permute.cpp)
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")

View File

@ -0,0 +1,91 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "mmdeploy/operation/vision.h"
#include "mmdeploy/utils/opencv/opencv_utils.h"
namespace mmdeploy::operation::cpu {
class PermuteImpl : public Permute {
public:
explicit PermuteImpl() {}
Result<void> apply(const Tensor& src, Tensor& dst, const std::vector<int>& axes) override {
int ndim = src.shape().size();
if (ndim != axes.size()) {
MMDEPLOY_ERROR("The size of axes should be equal to src, {} vs {}", axes.size(), ndim);
return Status(eInvalidArgument);
}
std::vector<int> axes_vis(ndim, 0);
for (const auto& x : axes) {
if (x < 0 || x >= ndim || axes_vis[x]) {
MMDEPLOY_ERROR("Invalid axes");
return Status(eInvalidArgument);
}
axes_vis[x] = 1;
}
Tensor dst_tensor(src.desc());
auto src_dims = src.shape();
TensorShape dst_dims(ndim);
for (int i = 0; i < src_dims.size(); i++) {
dst_dims[i] = src_dims[axes[i]];
}
dst_tensor.Reshape(dst_dims);
std::vector<int> dst_strides(ndim);
std::vector<int> src_strides(ndim);
dst_strides[ndim - 1] = src_strides[ndim - 1] = 1;
for (int i = ndim - 2; i >= 0; i--) {
dst_strides[i] = dst_strides[i + 1] * dst_dims[i + 1];
src_strides[i] = src_strides[i + 1] * src_dims[i + 1];
}
std::vector<int> tmp(ndim);
for (int i = 0; i < ndim; i++) {
tmp[i] = src_strides[axes[i]];
}
src_strides.swap(tmp);
if (src.data_type() == DataType::kINT8) {
OUTCOME_TRY(PermuteDispatch<uint8_t>(src, dst_tensor, src_strides, dst_strides));
} else if (src.data_type() == DataType::kFLOAT) {
OUTCOME_TRY(PermuteDispatch<float>(src, dst_tensor, src_strides, dst_strides));
} else {
MMDEPLOY_ERROR("unsupported data type {}", src.data_type());
return Status(eNotSupported);
}
dst = std::move(dst_tensor);
return success();
}
template <typename T>
Result<void> PermuteDispatch(const Tensor& src, Tensor& dst, const std::vector<int>& src_strides,
const std::vector<int>& dst_strides) {
auto shape = dst.shape();
int ndim = src.shape().size();
std::vector<int> coord(ndim, 0);
auto dst_data = dst.data<T>();
auto src_data = src.data<T>();
int i;
do {
dst_data[0] = src_data[0];
for (i = ndim - 1; i >= 0; i--) {
if (++coord[i] == shape[i]) {
coord[i] = 0;
dst_data -= (shape[i] - 1) * dst_strides[i];
src_data -= (shape[i] - 1) * src_strides[i];
} else {
dst_data += dst_strides[i];
src_data += src_strides[i];
break;
}
}
} while (i >= 0);
return success();
}
};
MMDEPLOY_REGISTER_FACTORY_FUNC(Permute, (cpu, 0), []() { return std::make_unique<PermuteImpl>(); });
} // namespace mmdeploy::operation::cpu
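
The do/while loop in PermuteDispatch walks the destination tensor in row-major order while an odometer-style coordinate counter advances the source pointer by permuted strides. The standalone snippet below runs the same loop by hand for a 2x3 transpose so the stride bookkeeping can be checked in isolation; it is an illustration, not SDK code.

// Standalone illustration of the stride walk above: transpose a 2x3 row-major
// array with axes {1, 0}. Strides are computed by hand for this small case.
#include <cassert>
#include <vector>

int main() {
  std::vector<int> src = {0, 1, 2, 3, 4, 5};  // shape {2, 3}
  std::vector<int> dst(6);                    // shape {3, 2}
  std::vector<int> shape = {3, 2};            // dst shape
  std::vector<int> dst_strides = {2, 1};      // row-major strides of dst
  std::vector<int> src_strides = {1, 3};      // src strides permuted by axes
  int ndim = 2;
  std::vector<int> coord(ndim, 0);
  int* d = dst.data();
  const int* s = src.data();
  int i;
  do {
    *d = *s;
    for (i = ndim - 1; i >= 0; i--) {
      if (++coord[i] == shape[i]) {
        coord[i] = 0;
        d -= (shape[i] - 1) * dst_strides[i];
        s -= (shape[i] - 1) * src_strides[i];
      } else {
        d += dst_strides[i];
        s += src_strides[i];
        break;
      }
    }
  } while (i >= 0);
  assert((dst == std::vector<int>{0, 3, 1, 4, 2, 5}));
  return 0;
}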

View File

@ -19,7 +19,9 @@ set(SRCS resize.cpp
crop.cu
flip.cpp
warp_affine.cpp
crop_resize_pad.cpp)
crop_resize_pad.cpp
permute.cpp
permute.cu)
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")

View File

@ -0,0 +1,91 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "mmdeploy/operation/cuda/permute.h"
#include <cuda_runtime.h>
#include "mmdeploy/operation/vision.h"
namespace mmdeploy::operation::cuda {
namespace impl {
template <typename T>
void Permute(const T* src, const TensorStride& src_strides, T* dst, const TensorStride& dst_strides,
int ndim, int total, cudaStream_t stream);
}
class PermuteImpl : public Permute {
public:
explicit PermuteImpl() {}
Result<void> apply(const Tensor& src, Tensor& dst, const std::vector<int>& axes) override {
int ndim = src.shape().size();
if (ndim != axes.size()) {
MMDEPLOY_ERROR("The size of axes should be equal of src, {} vs {}", axes.size(), ndim);
return Status(eInvalidArgument);
}
if (ndim > MAX_PERMUTE_DIM) {
MMDEPLOY_ERROR("Only support ndim < {}", MAX_PERMUTE_DIM);
return Status(eInvalidArgument);
}
std::vector<int> axes_vis(ndim, 0);
for (const auto& x : axes) {
if (x < 0 || x >= ndim || axes_vis[x]) {
MMDEPLOY_ERROR("Invalid axes");
return Status(eInvalidArgument);
}
axes_vis[x] = 1;
}
Tensor dst_tensor(src.desc());
auto src_dims = src.shape();
TensorShape dst_dims(ndim);
for (int i = 0; i < src_dims.size(); i++) {
dst_dims[i] = src_dims[axes[i]];
}
dst_tensor.Reshape(dst_dims);
TensorStride dst_strides;
TensorStride src_strides;
dst_strides[ndim - 1] = src_strides[ndim - 1] = 1;
for (int i = ndim - 2; i >= 0; i--) {
dst_strides[i] = dst_strides[i + 1] * dst_dims[i + 1];
src_strides[i] = src_strides[i + 1] * src_dims[i + 1];
}
TensorStride tmp;
for (int i = 0; i < ndim; i++) {
tmp[i] = src_strides[axes[i]];
}
src_strides = tmp;
if (src.data_type() == DataType::kINT8) {
OUTCOME_TRY(PermuteDispatch<uint8_t>(src, dst_tensor, src_strides, dst_strides));
} else if (src.data_type() == DataType::kFLOAT) {
OUTCOME_TRY(PermuteDispatch<float>(src, dst_tensor, src_strides, dst_strides));
} else {
MMDEPLOY_ERROR("unsupported data type {}", src.data_type());
return Status(eNotSupported);
}
dst = std::move(dst_tensor);
return success();
}
template <typename T>
Result<void> PermuteDispatch(const Tensor& src, Tensor& dst, const TensorStride& src_strides,
const TensorStride& dst_strides) {
auto src_data = src.data<T>();
auto dst_data = dst.data<T>();
auto ndim = src.shape().size();
auto total = src.size();
impl::Permute(src_data, src_strides, dst_data, dst_strides, ndim, total,
GetNative<cudaStream_t>(stream()));
return success();
}
};
MMDEPLOY_REGISTER_FACTORY_FUNC(Permute, (cuda, 0),
[]() { return std::make_unique<PermuteImpl>(); });
} // namespace mmdeploy::operation::cuda

View File

@ -0,0 +1,49 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <cstdint>
#include "mmdeploy/operation/cuda/permute.h"
namespace mmdeploy {
namespace operation {
namespace cuda {
namespace impl {
template <typename T>
__global__ void permute(const T* src, const TensorStride src_strides, T* dst,
const TensorStride dst_strides, int ndim, int total) {
int u = blockIdx.x * blockDim.x + threadIdx.x;
if (u >= total) {
return;
}
int remaining = u;
int v = 0;
for (size_t i = 0; i < ndim; i++) {
int p = remaining / dst_strides.v_[i];
remaining -= p * dst_strides.v_[i];
v += p * src_strides.v_[i];
}
dst[u] = src[v];
}
template <typename T>
void Permute(const T* src, const TensorStride& src_strides, T* dst, const TensorStride& dst_strides,
int ndim, int total, cudaStream_t stream) {
int thread_num = 256;
int block_num = (total + thread_num - 1) / thread_num;
permute<T><<<block_num, thread_num, 0, stream>>>(src, src_strides, dst, dst_strides, ndim, total);
}
template void Permute<float>(const float* src, const TensorStride& src_strides, float* dst,
const TensorStride& dst_strides, int ndim, int total,
cudaStream_t stream);
template void Permute<uint8_t>(const uint8_t* src, const TensorStride& src_strides, uint8_t* dst,
const TensorStride& dst_strides, int ndim, int total,
cudaStream_t stream);
} // namespace impl
} // namespace cuda
} // namespace operation
} // namespace mmdeploy
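
Each thread of the kernel above maps its destination linear index u to a source index v by peeling u apart with dst_strides and re-folding the digits with the permuted src_strides. The host-side snippet below reproduces that mapping for a 2x3 -> 3x2 transpose as a sanity check; the helper is hypothetical and only mirrors the arithmetic in the kernel.

// Host-side illustration of the per-thread index mapping (hypothetical helper).
#include <cassert>

int SrcIndex(int u, const int* dst_strides, const int* src_strides, int ndim) {
  int remaining = u, v = 0;
  for (int i = 0; i < ndim; ++i) {
    int p = remaining / dst_strides[i];
    remaining -= p * dst_strides[i];
    v += p * src_strides[i];
  }
  return v;
}

int main() {
  // 2x3 -> 3x2 transpose: dst strides {2, 1}, src strides permuted to {1, 3}.
  const int dst_strides[] = {2, 1};
  const int src_strides[] = {1, 3};
  assert(SrcIndex(0, dst_strides, src_strides, 2) == 0);  // dst[0][0] <- src[0][0]
  assert(SrcIndex(1, dst_strides, src_strides, 2) == 3);  // dst[0][1] <- src[1][0]
  assert(SrcIndex(4, dst_strides, src_strides, 2) == 2);  // dst[2][0] <- src[0][2]
  return 0;
}

Passing TensorStride by value (a fixed array of MAX_PERMUTE_DIM ints) is what lets the kernel drop the device-side stride buffers referred to in the "remove stride buffer" commit.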

View File

@ -0,0 +1,24 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_OPERATION_CUDA_PERMUTE_H_
#define MMDEPLOY_OPERATION_CUDA_PERMUTE_H_
#include <cuda_runtime.h>
#include <cstdlib>
namespace mmdeploy {
namespace operation {
namespace cuda {
const int MAX_PERMUTE_DIM = 8;
struct TensorStride {
int v_[MAX_PERMUTE_DIM];
int& operator[](size_t idx) { return v_[idx]; }
};
} // namespace cuda
} // namespace operation
} // namespace mmdeploy
#endif // MMDEPLOY_OPERATION_CUDA_PERMUTE_H_

View File

@ -14,5 +14,6 @@ MMDEPLOY_DEFINE_REGISTRY(Crop);
MMDEPLOY_DEFINE_REGISTRY(Flip);
MMDEPLOY_DEFINE_REGISTRY(WarpAffine);
MMDEPLOY_DEFINE_REGISTRY(CropResizePad);
MMDEPLOY_DEFINE_REGISTRY(Permute);
} // namespace mmdeploy::operation

View File

@ -92,6 +92,11 @@ class CropResizePad : public Operation {
};
MMDEPLOY_DECLARE_REGISTRY(CropResizePad, unique_ptr<CropResizePad>());
class Permute : public Operation {
public:
virtual Result<void> apply(const Tensor& src, Tensor& dst, const std::vector<int>& axes) = 0;
};
MMDEPLOY_DECLARE_REGISTRY(Permute, unique_ptr<Permute>());
} // namespace mmdeploy::operation
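
Permute follows the same managed-operation pattern as the other ops declared here: create it under an operation::Context bound to a device and stream, then call Apply with an axes vector, as the ResizeMask and unit-test changes in this PR do. A minimal sketch (the wrapper function is illustrative, not library code):

// Minimal usage sketch; ToCHW is an illustrative wrapper, not part of mmdeploy.
#include <vector>

#include "mmdeploy/core/tensor.h"
#include "mmdeploy/operation/managed.h"
#include "mmdeploy/operation/vision.h"

namespace example {

using namespace mmdeploy;
using namespace mmdeploy::framework;

// Permute an HWC tensor to CHW on the given device/stream.
Result<Tensor> ToCHW(const Tensor& hwc, const Device& device, Stream& stream) {
  operation::Context ctx(device, stream);
  auto permute = operation::Managed<operation::Permute>::Create();
  Tensor chw;
  OUTCOME_TRY(permute.Apply(hwc, chw, std::vector<int>{2, 0, 1}));  // HWC -> CHW
  return chw;
}

}  // namespace example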

View File

@ -126,7 +126,7 @@ cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width,
const std::string& interpolation) {
cv::Mat dst(dst_height, dst_width, src.type());
auto method = GetInterpolationMethod(interpolation).value();
cv::resize(src, dst, dst.size(), method);
cv::resize(src, dst, dst.size(), 0, 0, method);
return dst;
}

View File

@ -75,21 +75,50 @@ namespace image_segmentation
unsafe
{
byte* data = colorMask.DataPointer;
fixed (int* _label = output[0].Mask)
if (output[0].Mask.Length > 0)
{
int* label = _label;
for (int i = 0; i < output[0].Height; i++)
fixed (int* _label = output[0].Mask)
{
for (int j = 0; j < output[0].Width; j++)
int* label = _label;
for (int i = 0; i < output[0].Height; i++)
{
data[0] = palette[*label][0];
data[1] = palette[*label][1];
data[2] = palette[*label][2];
data += 3;
label++;
for (int j = 0; j < output[0].Width; j++)
{
data[0] = palette[*label][0];
data[1] = palette[*label][1];
data[2] = palette[*label][2];
data += 3;
label++;
}
}
}
}
else
{
int pos = 0;
fixed (float* _score = output[0].Score)
{
float *score = _score;
int total = output[0].Height * output[0].Width;
for (int i = 0; i < output[0].Height; i++)
{
for (int j = 0; j < output[0].Width; j++)
{
List<Tuple<float, int>> scores = new List<Tuple<float, int>>();
for (int k = 0; k < output[0].Classes; k++)
{
scores.Add(new Tuple<float, int>(score[k * total + i * output[0].Width + j], k));
}
scores.Sort();
data[0] = palette[scores[^1].Item2][0];
data[1] = palette[scores[^1].Item2][1];
data[2] = palette[scores[^1].Item2][2];
data += 3;
}
}
}
}
}
colorMask = imgs[0] * 0.5 + colorMask * 0.5;

View File

@ -1,6 +1,7 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <fstream>
#include <numeric>
#include <opencv2/imgcodecs/imgcodecs.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <random>
@ -59,8 +60,25 @@ int main(int argc, char* argv[]) {
cv::Mat color_mask = cv::Mat::zeros(result->height, result->width, CV_8UC3);
int pos = 0;
int total = color_mask.rows * color_mask.cols;
std::vector<int> idxs(result->classes);
for (auto iter = color_mask.begin<cv::Vec3b>(); iter != color_mask.end<cv::Vec3b>(); ++iter) {
*iter = palette[result->mask[pos++]];
// output mask
if (result->mask) {
*iter = palette[result->mask[pos++]];
}
// output score
if (result->score) {
std::iota(idxs.begin(), idxs.end(), 0);
auto k =
std::max_element(idxs.begin(), idxs.end(),
[&](int i, int j) {
return result->score[i * total + pos] < result->score[j * total + pos];
}) -
idxs.begin();
*iter = palette[k];
pos += 1;
}
}
img = img * 0.5 + color_mask * 0.5;

View File

@ -3,6 +3,7 @@
#include "mmdeploy/segmentor.hpp"
#include <fstream>
#include <numeric>
#include <opencv2/imgcodecs/imgcodecs.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <random>
@ -47,8 +48,25 @@ int main(int argc, char* argv[]) {
cv::Mat color_mask = cv::Mat::zeros(result->height, result->width, CV_8UC3);
int pos = 0;
int total = color_mask.rows * color_mask.cols;
std::vector<int> idxs(result->classes);
for (auto iter = color_mask.begin<cv::Vec3b>(); iter != color_mask.end<cv::Vec3b>(); ++iter) {
*iter = palette[result->mask[pos++]];
// output mask
if (result->mask) {
*iter = palette[result->mask[pos++]];
}
// output score
if (result->score) {
std::iota(idxs.begin(), idxs.end(), 0);
auto k =
std::max_element(idxs.begin(), idxs.end(),
[&](int i, int j) {
return result->score[pos + i * total] < result->score[pos + j * total];
}) -
idxs.begin();
*iter = palette[k];
pos += 1;
}
}
img = img * 0.5 + color_mask * 0.5;

View File

@ -35,6 +35,8 @@ def main():
segmentor = Segmentor(
model_path=args.model_path, device_name=args.device_name, device_id=0)
seg = segmentor(img)
if seg.dtype == np.float32:
seg = np.argmax(seg, axis=0)
palette = get_palette()
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)

View File

@ -0,0 +1,121 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <numeric>
#include "catch.hpp"
#include "mmdeploy/core/mat.h"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/operation/managed.h"
#include "mmdeploy/operation/vision.h"
#include "mmdeploy/preprocess/transform/transform.h"
#include "test_resource.h"
#include "test_utils.h"
using namespace mmdeploy;
using namespace framework;
using namespace std;
using namespace mmdeploy::test;
template <typename T>
bool CheckEqual(const Tensor& res, const vector<T>& expected) {
auto r = res.data<T>();
auto e = expected.data();
for (int i = 0; i < expected.size(); i++) {
if (r[i] != e[i]) {
return false;
}
}
return true;
}
template <typename T>
void TestPermute(const Tensor& src, const vector<int>& axes, const vector<T>& expected) {
auto gResource = MMDeployTestResources::Get();
for (auto const& device_name : gResource.device_names()) {
Device device{device_name.c_str()};
Stream stream{device};
::mmdeploy::operation::Context ctx(device, stream);
auto permute = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create();
Tensor dst;
auto ret = permute.Apply(src, dst, axes);
REQUIRE(!ret.has_error());
const Device kHost{"cpu"};
auto host_tensor = MakeAvailableOnDevice(dst, kHost, stream);
REQUIRE(CheckEqual(host_tensor.value(), expected));
}
}
void TestPermuteWrongArgs(const Tensor& src) {
int sz = src.shape().size();
vector<int> oaxes(sz);
std::iota(oaxes.begin(), oaxes.end(), 0);
auto gResource = MMDeployTestResources::Get();
for (auto const& device_name : gResource.device_names()) {
Device device{device_name.c_str()};
Stream stream{device};
::mmdeploy::operation::Context ctx(device, stream);
auto permute = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create();
Tensor dst;
{
auto axes = oaxes;
axes[0]--;
auto ret = permute.Apply(src, dst, axes);
REQUIRE(ret.has_error());
}
{
auto axes = oaxes;
axes.back()++;
auto ret = permute.Apply(src, dst, axes);
REQUIRE(ret.has_error());
}
{
auto axes = oaxes;
axes[0] = axes[1];
auto ret = permute.Apply(src, dst, axes);
REQUIRE(ret.has_error());
}
}
}
TEST_CASE("operation Permute", "[permute]") {
const Device kHost{"cpu"};
const int kSize = 2 * 3 * 2 * 4;
vector<uint8_t> data(kSize);
std::iota(data.begin(), data.end(), 0); // [0, 48)
TensorDesc desc = {kHost, DataType::kINT8, {kSize}};
Tensor tensor(desc);
memcpy(tensor.data(), data.data(), data.size() * sizeof(uint8_t));
SECTION("permute: wrong axes") {
Tensor src = tensor;
src.Reshape({6, 8});
TestPermuteWrongArgs(src);
}
SECTION("permute: dims 4") {
Tensor src = tensor;
src.Reshape({2, 3, 2, 4});
vector<int> axes = {1, 0, 3, 2};
vector<uint8_t> expected = {0, 4, 1, 5, 2, 6, 3, 7, 24, 28, 25, 29, 26, 30, 27, 31,
8, 12, 9, 13, 10, 14, 11, 15, 32, 36, 33, 37, 34, 38, 35, 39,
16, 20, 17, 21, 18, 22, 19, 23, 40, 44, 41, 45, 42, 46, 43, 47};
Tensor dst(src.desc());
memcpy(dst.data(), expected.data(), data.size() * sizeof(uint8_t));
TestPermute(src, axes, expected);
}
SECTION("permute: dims 5") {
Tensor src = tensor;
src.Reshape({2, 3, 1, 2, 4});
vector<int> axes = {2, 0, 1, 4, 3};
vector<uint8_t> expected = {0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15,
16, 20, 17, 21, 18, 22, 19, 23, 24, 28, 25, 29, 26, 30, 27, 31,
32, 36, 33, 37, 34, 38, 35, 39, 40, 44, 41, 45, 42, 46, 43, 47};
Tensor dst(src.desc());
memcpy(dst.data(), expected.data(), data.size() * sizeof(uint8_t));
TestPermute(src, axes, expected);
}
}