[Enhancement] Optimize pose tracker (#1460)
* sync master * suppress overlapped tracks * add CUDA WarpAffine * export symbols * fix linkage * update pose tracker * clean-up * fix MSVC build * fix MSVC build * add ffmpeg cli commandpull/1593/head
parent
f62352a5fa
commit
20e0563682
|
@ -6,7 +6,9 @@ project(mmdeploy_mmpose)
|
||||||
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
|
file(GLOB_RECURSE SRCS ${CMAKE_CURRENT_SOURCE_DIR} "*.cpp")
|
||||||
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
|
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
|
||||||
target_link_libraries(${PROJECT_NAME} PRIVATE
|
target_link_libraries(${PROJECT_NAME} PRIVATE
|
||||||
mmdeploy::transform mmdeploy_opencv_utils)
|
mmdeploy::transform
|
||||||
|
mmdeploy_operation
|
||||||
|
mmdeploy_opencv_utils)
|
||||||
add_library(mmdeploy::mmpose ALIAS ${PROJECT_NAME})
|
add_library(mmdeploy::mmpose ALIAS ${PROJECT_NAME})
|
||||||
|
|
||||||
set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} pose_detector CACHE INTERNAL "")
|
set(MMDEPLOY_TASKS ${MMDEPLOY_TASKS} pose_detector CACHE INTERNAL "")
|
||||||
|
|
|
@ -7,6 +7,8 @@
|
||||||
#include "mmdeploy/core/tensor.h"
|
#include "mmdeploy/core/tensor.h"
|
||||||
#include "mmdeploy/core/utils/device_utils.h"
|
#include "mmdeploy/core/utils/device_utils.h"
|
||||||
#include "mmdeploy/core/utils/formatter.h"
|
#include "mmdeploy/core/utils/formatter.h"
|
||||||
|
#include "mmdeploy/operation/managed.h"
|
||||||
|
#include "mmdeploy/operation/vision.h"
|
||||||
#include "mmdeploy/preprocess/transform/transform.h"
|
#include "mmdeploy/preprocess/transform/transform.h"
|
||||||
#include "opencv2/imgproc.hpp"
|
#include "opencv2/imgproc.hpp"
|
||||||
#include "opencv_utils.h"
|
#include "opencv_utils.h"
|
||||||
|
@ -32,6 +34,7 @@ class TopDownAffine : public transform::Transform {
|
||||||
stream_ = args["context"]["stream"].get<Stream>();
|
stream_ = args["context"]["stream"].get<Stream>();
|
||||||
assert(args.contains("image_size"));
|
assert(args.contains("image_size"));
|
||||||
from_value(args["image_size"], image_size_);
|
from_value(args["image_size"], image_size_);
|
||||||
|
warp_affine_ = operation::Managed<operation::WarpAffine>::Create("bilinear");
|
||||||
}
|
}
|
||||||
|
|
||||||
~TopDownAffine() override = default;
|
~TopDownAffine() override = default;
|
||||||
|
@ -39,11 +42,7 @@ class TopDownAffine : public transform::Transform {
|
||||||
Result<void> Apply(Value& data) override {
|
Result<void> Apply(Value& data) override {
|
||||||
MMDEPLOY_DEBUG("top_down_affine input: {}", data);
|
MMDEPLOY_DEBUG("top_down_affine input: {}", data);
|
||||||
|
|
||||||
Device host{"cpu"};
|
auto img = data["img"].get<Tensor>();
|
||||||
auto _img = data["img"].get<Tensor>();
|
|
||||||
OUTCOME_TRY(auto img, MakeAvailableOnDevice(_img, host, stream_));
|
|
||||||
stream_.Wait().value();
|
|
||||||
auto src = cpu::Tensor2CVMat(img);
|
|
||||||
|
|
||||||
// prepare data
|
// prepare data
|
||||||
vector<float> bbox;
|
vector<float> bbox;
|
||||||
|
@ -62,21 +61,20 @@ class TopDownAffine : public transform::Transform {
|
||||||
|
|
||||||
auto r = data["rotation"].get<float>();
|
auto r = data["rotation"].get<float>();
|
||||||
|
|
||||||
cv::Mat dst;
|
Tensor dst;
|
||||||
if (use_udp_) {
|
if (use_udp_) {
|
||||||
cv::Mat trans =
|
cv::Mat trans =
|
||||||
GetWarpMatrix(r, {c[0] * 2.f, c[1] * 2.f}, {image_size_[0] - 1.f, image_size_[1] - 1.f},
|
GetWarpMatrix(r, {c[0] * 2.f, c[1] * 2.f}, {image_size_[0] - 1.f, image_size_[1] - 1.f},
|
||||||
{s[0] * 200.f, s[1] * 200.f});
|
{s[0] * 200.f, s[1] * 200.f});
|
||||||
|
OUTCOME_TRY(warp_affine_.Apply(img, dst, trans.ptr<float>(), image_size_[1], image_size_[0]));
|
||||||
cv::warpAffine(src, dst, trans, {image_size_[0], image_size_[1]}, cv::INTER_LINEAR);
|
|
||||||
} else {
|
} else {
|
||||||
cv::Mat trans =
|
cv::Mat trans =
|
||||||
GetAffineTransform({c[0], c[1]}, {s[0], s[1]}, r, {image_size_[0], image_size_[1]});
|
GetAffineTransform({c[0], c[1]}, {s[0], s[1]}, r, {image_size_[0], image_size_[1]});
|
||||||
cv::warpAffine(src, dst, trans, {image_size_[0], image_size_[1]}, cv::INTER_LINEAR);
|
OUTCOME_TRY(warp_affine_.Apply(img, dst, trans.ptr<float>(), image_size_[1], image_size_[0]));
|
||||||
}
|
}
|
||||||
|
|
||||||
data["img"] = cpu::CVMat2Tensor(dst);
|
data["img_shape"] = {1, image_size_[1], image_size_[0], dst.shape(3)};
|
||||||
data["img_shape"] = {1, image_size_[1], image_size_[0], dst.channels()};
|
data["img"] = std::move(dst);
|
||||||
data["center"] = to_value(c);
|
data["center"] = to_value(c);
|
||||||
data["scale"] = to_value(s);
|
data["scale"] = to_value(s);
|
||||||
MMDEPLOY_DEBUG("output: {}", data);
|
MMDEPLOY_DEBUG("output: {}", data);
|
||||||
|
@ -106,7 +104,7 @@ class TopDownAffine : public transform::Transform {
|
||||||
theta = theta * 3.1415926 / 180;
|
theta = theta * 3.1415926 / 180;
|
||||||
float scale_x = size_dst.width / size_target.width;
|
float scale_x = size_dst.width / size_target.width;
|
||||||
float scale_y = size_dst.height / size_target.height;
|
float scale_y = size_dst.height / size_target.height;
|
||||||
cv::Mat matrix = cv::Mat(2, 3, CV_32FC1);
|
cv::Mat matrix = cv::Mat(2, 3, CV_32F);
|
||||||
matrix.at<float>(0, 0) = std::cos(theta) * scale_x;
|
matrix.at<float>(0, 0) = std::cos(theta) * scale_x;
|
||||||
matrix.at<float>(0, 1) = -std::sin(theta) * scale_x;
|
matrix.at<float>(0, 1) = -std::sin(theta) * scale_x;
|
||||||
matrix.at<float>(0, 2) =
|
matrix.at<float>(0, 2) =
|
||||||
|
@ -142,6 +140,7 @@ class TopDownAffine : public transform::Transform {
|
||||||
|
|
||||||
cv::Mat trans = inv ? cv::getAffineTransform(dst_points, src_points)
|
cv::Mat trans = inv ? cv::getAffineTransform(dst_points, src_points)
|
||||||
: cv::getAffineTransform(src_points, dst_points);
|
: cv::getAffineTransform(src_points, dst_points);
|
||||||
|
trans.convertTo(trans, CV_32F);
|
||||||
return trans;
|
return trans;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -160,6 +159,7 @@ class TopDownAffine : public transform::Transform {
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
operation::Managed<operation::WarpAffine> warp_affine_;
|
||||||
bool use_udp_{false};
|
bool use_udp_{false};
|
||||||
vector<int> image_size_;
|
vector<int> image_size_;
|
||||||
std::string backend_;
|
std::string backend_;
|
||||||
|
|
|
@ -9,7 +9,8 @@ set(SRCS resize.cpp
|
||||||
hwc2chw.cpp
|
hwc2chw.cpp
|
||||||
normalize.cpp
|
normalize.cpp
|
||||||
crop.cpp
|
crop.cpp
|
||||||
flip.cpp)
|
flip.cpp
|
||||||
|
warp_affine.cpp)
|
||||||
|
|
||||||
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
|
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ namespace mmdeploy::operation::cpu {
|
||||||
|
|
||||||
class ResizeImpl : public Resize {
|
class ResizeImpl : public Resize {
|
||||||
public:
|
public:
|
||||||
ResizeImpl(std::string interp) : interp_(std::move(interp)) {}
|
explicit ResizeImpl(std::string interp) : interp_(std::move(interp)) {}
|
||||||
|
|
||||||
Result<void> apply(const Tensor& src, Tensor& dst, int dst_h, int dst_w) override {
|
Result<void> apply(const Tensor& src, Tensor& dst, int dst_h, int dst_w) override {
|
||||||
auto src_mat = mmdeploy::cpu::Tensor2CVMat(src);
|
auto src_mat = mmdeploy::cpu::Tensor2CVMat(src);
|
||||||
|
|
|
@ -0,0 +1,29 @@
|
||||||
|
// Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
|
||||||
|
#include "mmdeploy/operation/vision.h"
|
||||||
|
#include "mmdeploy/utils/opencv/opencv_utils.h"
|
||||||
|
|
||||||
|
namespace mmdeploy::operation::cpu {
|
||||||
|
|
||||||
|
class WarpAffineImpl : public WarpAffine {
|
||||||
|
public:
|
||||||
|
explicit WarpAffineImpl(int method) : method_(method) {}
|
||||||
|
|
||||||
|
Result<void> apply(const Tensor& src, Tensor& dst, const float affine_matrix[6], int dst_h,
|
||||||
|
int dst_w) override {
|
||||||
|
auto src_mat = mmdeploy::cpu::Tensor2CVMat(src);
|
||||||
|
cv::Mat_<float> _matrix(2, 3, const_cast<float*>(affine_matrix));
|
||||||
|
auto dst_mat = mmdeploy::cpu::WarpAffine(src_mat, _matrix, dst_h, dst_w, method_);
|
||||||
|
dst = mmdeploy::cpu::CVMat2Tensor(dst_mat);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int method_;
|
||||||
|
};
|
||||||
|
|
||||||
|
MMDEPLOY_REGISTER_FACTORY_FUNC(WarpAffine, (cpu, 0), [](const string_view& interp) {
|
||||||
|
return std::make_unique<WarpAffineImpl>(::mmdeploy::cpu::GetInterpolationMethod(interp).value());
|
||||||
|
});
|
||||||
|
|
||||||
|
} // namespace mmdeploy::operation::cpu
|
|
@ -17,7 +17,8 @@ set(SRCS resize.cpp
|
||||||
normalize.cu
|
normalize.cu
|
||||||
crop.cpp
|
crop.cpp
|
||||||
crop.cu
|
crop.cu
|
||||||
flip.cpp)
|
flip.cpp
|
||||||
|
warp_affine.cpp)
|
||||||
|
|
||||||
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
|
mmdeploy_add_module(${PROJECT_NAME} "${SRCS}")
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
// Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
|
||||||
|
#include "mmdeploy/core/utils/formatter.h"
|
||||||
|
#include "mmdeploy/operation/vision.h"
|
||||||
|
#include "ppl/cv/cuda/warpaffine.h"
|
||||||
|
|
||||||
|
namespace mmdeploy::operation::cuda {
|
||||||
|
|
||||||
|
class WarpAffineImpl : public WarpAffine {
|
||||||
|
public:
|
||||||
|
explicit WarpAffineImpl(ppl::cv::InterpolationType interp) : interp_(interp) {}
|
||||||
|
|
||||||
|
Result<void> apply(const Tensor& src, Tensor& dst, const float affine_matrix[6], int dst_h,
|
||||||
|
int dst_w) override {
|
||||||
|
assert(src.device() == device());
|
||||||
|
|
||||||
|
TensorDesc desc{device(), src.data_type(), {1, dst_h, dst_w, src.shape(3)}, src.name()};
|
||||||
|
Tensor dst_tensor(desc);
|
||||||
|
|
||||||
|
const auto m = affine_matrix;
|
||||||
|
auto inv = Invert(affine_matrix);
|
||||||
|
|
||||||
|
auto cuda_stream = GetNative<cudaStream_t>(stream());
|
||||||
|
if (src.data_type() == DataType::kINT8) {
|
||||||
|
OUTCOME_TRY(Dispatch<uint8_t>(src, dst_tensor, inv.data(), cuda_stream));
|
||||||
|
} else if (src.data_type() == DataType::kFLOAT) {
|
||||||
|
OUTCOME_TRY(Dispatch<float>(src, dst_tensor, inv.data(), cuda_stream));
|
||||||
|
} else {
|
||||||
|
MMDEPLOY_ERROR("unsupported data type {}", src.data_type());
|
||||||
|
return Status(eNotSupported);
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = std::move(dst_tensor);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// ppl.cv uses inverted transform
|
||||||
|
// https://github.com/opencv/opencv/blob/bc6544c0bcfa9ca5db5e0d0551edf5c8e7da3852/modules/imgproc/src/imgwarp.cpp#L3478
|
||||||
|
static std::array<float, 6> Invert(const float affine_matrix[6]) {
|
||||||
|
const auto* M = affine_matrix;
|
||||||
|
std::array<float, 6> inv{};
|
||||||
|
auto iM = inv.data();
|
||||||
|
|
||||||
|
auto D = M[0] * M[3 + 1] - M[1] * M[3];
|
||||||
|
D = D != 0.f ? 1.f / D : 0.f;
|
||||||
|
auto A11 = M[3 + 1] * D, A22 = M[0] * D, A12 = -M[1] * D, A21 = -M[3] * D;
|
||||||
|
auto b1 = -A11 * M[2] - A12 * M[3 + 2];
|
||||||
|
auto b2 = -A21 * M[2] - A22 * M[3 + 2];
|
||||||
|
|
||||||
|
iM[0] = A11;
|
||||||
|
iM[1] = A12;
|
||||||
|
iM[2] = b1;
|
||||||
|
iM[3] = A21;
|
||||||
|
iM[3 + 1] = A22;
|
||||||
|
iM[3 + 2] = b2;
|
||||||
|
|
||||||
|
return inv;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
auto Select(int channels) -> decltype(&ppl::cv::cuda::WarpAffine<T, 1>) {
|
||||||
|
switch (channels) {
|
||||||
|
case 1:
|
||||||
|
return &ppl::cv::cuda::WarpAffine<T, 1>;
|
||||||
|
case 3:
|
||||||
|
return &ppl::cv::cuda::WarpAffine<T, 3>;
|
||||||
|
case 4:
|
||||||
|
return &ppl::cv::cuda::WarpAffine<T, 4>;
|
||||||
|
default:
|
||||||
|
MMDEPLOY_ERROR("unsupported channels {}", channels);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
Result<void> Dispatch(const Tensor& src, Tensor& dst, const float affine_matrix[6],
|
||||||
|
cudaStream_t stream) {
|
||||||
|
int h = (int)src.shape(1);
|
||||||
|
int w = (int)src.shape(2);
|
||||||
|
int c = (int)src.shape(3);
|
||||||
|
int dst_h = (int)dst.shape(1);
|
||||||
|
int dst_w = (int)dst.shape(2);
|
||||||
|
|
||||||
|
auto input = src.data<T>();
|
||||||
|
auto output = dst.data<T>();
|
||||||
|
|
||||||
|
ppl::common::RetCode ret = 0;
|
||||||
|
|
||||||
|
if (auto warp_affine = Select<T>(c); warp_affine) {
|
||||||
|
ret = warp_affine(stream, h, w, w * c, input, dst_h, dst_w, dst_w * c, output, affine_matrix,
|
||||||
|
interp_, ppl::cv::BORDER_CONSTANT, 0);
|
||||||
|
} else {
|
||||||
|
return Status(eNotSupported);
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret == 0 ? success() : Result<void>(Status(eFail));
|
||||||
|
}
|
||||||
|
|
||||||
|
ppl::cv::InterpolationType interp_;
|
||||||
|
};
|
||||||
|
|
||||||
|
static auto Create(const string_view& interp) {
|
||||||
|
ppl::cv::InterpolationType type{};
|
||||||
|
if (interp == "bilinear") {
|
||||||
|
type = ppl::cv::InterpolationType::INTERPOLATION_LINEAR;
|
||||||
|
} else if (interp == "nearest") {
|
||||||
|
type = ppl::cv::InterpolationType::INTERPOLATION_NEAREST_POINT;
|
||||||
|
} else {
|
||||||
|
MMDEPLOY_ERROR("unsupported interpolation method: {}", interp);
|
||||||
|
throw_exception(eNotSupported);
|
||||||
|
}
|
||||||
|
return std::make_unique<WarpAffineImpl>(type);
|
||||||
|
}
|
||||||
|
|
||||||
|
MMDEPLOY_REGISTER_FACTORY_FUNC(WarpAffine, (cuda, 0), Create);
|
||||||
|
|
||||||
|
} // namespace mmdeploy::operation::cuda
|
|
@ -12,5 +12,6 @@ MMDEPLOY_DEFINE_REGISTRY(HWC2CHW);
|
||||||
MMDEPLOY_DEFINE_REGISTRY(Normalize);
|
MMDEPLOY_DEFINE_REGISTRY(Normalize);
|
||||||
MMDEPLOY_DEFINE_REGISTRY(Crop);
|
MMDEPLOY_DEFINE_REGISTRY(Crop);
|
||||||
MMDEPLOY_DEFINE_REGISTRY(Flip);
|
MMDEPLOY_DEFINE_REGISTRY(Flip);
|
||||||
|
MMDEPLOY_DEFINE_REGISTRY(WarpAffine);
|
||||||
|
|
||||||
} // namespace mmdeploy::operation
|
} // namespace mmdeploy::operation
|
||||||
|
|
|
@ -76,7 +76,13 @@ class Flip : public Operation {
|
||||||
};
|
};
|
||||||
MMDEPLOY_DECLARE_REGISTRY(Flip, unique_ptr<Flip>(int flip_code));
|
MMDEPLOY_DECLARE_REGISTRY(Flip, unique_ptr<Flip>(int flip_code));
|
||||||
|
|
||||||
// TODO: warp affine
|
// 2x3 OpenCV affine matrix, row major
|
||||||
|
class WarpAffine : public Operation {
|
||||||
|
public:
|
||||||
|
virtual Result<void> apply(const Tensor& src, Tensor& dst, const float affine_matrix[6],
|
||||||
|
int dst_h, int dst_w) = 0;
|
||||||
|
};
|
||||||
|
MMDEPLOY_DECLARE_REGISTRY(WarpAffine, unique_ptr<WarpAffine>(const string_view& interp));
|
||||||
|
|
||||||
} // namespace mmdeploy::operation
|
} // namespace mmdeploy::operation
|
||||||
|
|
||||||
|
|
|
@ -48,6 +48,12 @@ class PrepareImage : public Transform {
|
||||||
|
|
||||||
Result<void> Apply(Value& data) override {
|
Result<void> Apply(Value& data) override {
|
||||||
MMDEPLOY_DEBUG("input: {}", data);
|
MMDEPLOY_DEBUG("input: {}", data);
|
||||||
|
|
||||||
|
// early exit
|
||||||
|
if (data.contains("img") && data["img"].is_any<Tensor>()) {
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
assert(data.contains("ori_img"));
|
assert(data.contains("ori_img"));
|
||||||
|
|
||||||
Mat src_mat = data["ori_img"].get<Mat>();
|
Mat src_mat = data["ori_img"].get<Mat>();
|
||||||
|
|
|
@ -106,23 +106,34 @@ Tensor CVMat2Tensor(const cv::Mat& mat) {
|
||||||
return Tensor{desc, data};
|
return Tensor{desc, data};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Result<int> GetInterpolationMethod(const std::string_view& method) {
|
||||||
|
if (method == "bilinear") {
|
||||||
|
return cv::INTER_LINEAR;
|
||||||
|
} else if (method == "nearest") {
|
||||||
|
return cv::INTER_NEAREST;
|
||||||
|
} else if (method == "area") {
|
||||||
|
return cv::INTER_AREA;
|
||||||
|
} else if (method == "bicubic") {
|
||||||
|
return cv::INTER_CUBIC;
|
||||||
|
} else if (method == "lanczos") {
|
||||||
|
return cv::INTER_LANCZOS4;
|
||||||
|
}
|
||||||
|
MMDEPLOY_ERROR("unsupported interpolation method: {}", method);
|
||||||
|
return Status(eNotSupported);
|
||||||
|
}
|
||||||
|
|
||||||
cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width,
|
cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width,
|
||||||
const std::string& interpolation) {
|
const std::string& interpolation) {
|
||||||
cv::Mat dst(dst_height, dst_width, src.type());
|
cv::Mat dst(dst_height, dst_width, src.type());
|
||||||
if (interpolation == "bilinear") {
|
auto method = GetInterpolationMethod(interpolation).value();
|
||||||
cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_LINEAR);
|
cv::resize(src, dst, dst.size(), method);
|
||||||
} else if (interpolation == "nearest") {
|
return dst;
|
||||||
cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_NEAREST);
|
}
|
||||||
} else if (interpolation == "area") {
|
|
||||||
cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_AREA);
|
cv::Mat WarpAffine(const cv::Mat& src, const cv::Mat& affine_matrix, int dst_height, int dst_width,
|
||||||
} else if (interpolation == "bicubic") {
|
int interpolation) {
|
||||||
cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_CUBIC);
|
cv::Mat dst(dst_height, dst_width, src.type());
|
||||||
} else if (interpolation == "lanczos") {
|
cv::warpAffine(src, dst, affine_matrix, dst.size(), interpolation);
|
||||||
cv::resize(src, dst, dst.size(), 0, 0, cv::INTER_LANCZOS4);
|
|
||||||
} else {
|
|
||||||
MMDEPLOY_ERROR("{} interpolation is not supported", interpolation);
|
|
||||||
assert(0);
|
|
||||||
}
|
|
||||||
return dst;
|
return dst;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,8 @@ MMDEPLOY_API cv::Mat Tensor2CVMat(const framework::Tensor& tensor);
|
||||||
MMDEPLOY_API framework::Mat CVMat2Mat(const cv::Mat& mat, PixelFormat format);
|
MMDEPLOY_API framework::Mat CVMat2Mat(const cv::Mat& mat, PixelFormat format);
|
||||||
MMDEPLOY_API framework::Tensor CVMat2Tensor(const cv::Mat& mat);
|
MMDEPLOY_API framework::Tensor CVMat2Tensor(const cv::Mat& mat);
|
||||||
|
|
||||||
|
MMDEPLOY_API Result<int> GetInterpolationMethod(const std::string_view& method);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief resize an image to specified size
|
* @brief resize an image to specified size
|
||||||
*
|
*
|
||||||
|
@ -29,6 +31,9 @@ MMDEPLOY_API framework::Tensor CVMat2Tensor(const cv::Mat& mat);
|
||||||
MMDEPLOY_API cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width,
|
MMDEPLOY_API cv::Mat Resize(const cv::Mat& src, int dst_height, int dst_width,
|
||||||
const std::string& interpolation);
|
const std::string& interpolation);
|
||||||
|
|
||||||
|
MMDEPLOY_API cv::Mat WarpAffine(const cv::Mat& src, const cv::Mat& affine_matrix, int dst_height,
|
||||||
|
int dst_width, int interpolation);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief crop an image
|
* @brief crop an image
|
||||||
*
|
*
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue