391 lines
15 KiB
C++
391 lines
15 KiB
C++
// Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cmath>
#include <limits>
|
|
#include <opencv2/imgcodecs.hpp>
|
|
#include <opencv2/imgproc.hpp>
|
|
|
|
#include "core/device.h"
|
|
#include "core/registry.h"
|
|
#include "core/serialization.h"
|
|
#include "core/tensor.h"
|
|
#include "core/utils/device_utils.h"
|
|
#include "core/utils/formatter.h"
|
|
#include "core/value.h"
|
|
#include "experimental/module_adapter.h"
|
|
#include "mmpose.h"
|
|
#include "opencv_utils.h"
|
|
|
|
namespace mmdeploy::mmpose {
|
|
|
|
using std::string;
|
|
using std::vector;
|
|
|
|
template <class F>
|
|
struct _LoopBody : public cv::ParallelLoopBody {
|
|
F f_;
|
|
_LoopBody(F f) : f_(std::move(f)) {}
|
|
void operator()(const cv::Range& range) const override { f_(range); }
|
|
};
|
|
|
|
/// Returns a copy of `s` in which every character has been passed through std::tolower
/// (characters are widened via unsigned char first, so bytes >= 0x80 are handled safely).
std::string to_lower(const std::string& s) {
  std::string lowered;
  lowered.reserve(s.size());
  for (unsigned char ch : s) {
    lowered.push_back(static_cast<char>(std::tolower(ch)));
  }
  return lowered;
}
|
|
|
|
class TopdownHeatmapBaseHeadDecode : public MMPose {
|
|
public:
|
|
explicit TopdownHeatmapBaseHeadDecode(const Value& config) : MMPose(config) {
|
|
if (config.contains("params")) {
|
|
auto& params = config["params"];
|
|
flip_test_ = params.value("flip_test", flip_test_);
|
|
use_udp_ = params.value("use_udp", use_udp_);
|
|
target_type_ = params.value("target_type", target_type_);
|
|
valid_radius_factor_ = params.value("valid_radius_factor", valid_radius_factor_);
|
|
unbiased_decoding_ = params.value("unbiased_decoding", unbiased_decoding_);
|
|
post_process_ = params.value("post_process", post_process_);
|
|
shift_heatmap_ = params.value("shift_heatmap", shift_heatmap_);
|
|
modulate_kernel_ = params.value("modulate_kernel", modulate_kernel_);
|
|
}
|
|
}
|
|
|
|
Result<Value> operator()(const Value& _data, const Value& _prob) {
|
|
MMDEPLOY_DEBUG("preprocess_result: {}", _data);
|
|
MMDEPLOY_DEBUG("inference_result: {}", _prob);
|
|
|
|
Device cpu_device{"cpu"};
|
|
OUTCOME_TRY(auto heatmap,
|
|
MakeAvailableOnDevice(_prob["output"].get<Tensor>(), cpu_device, stream()));
|
|
OUTCOME_TRY(stream().Wait());
|
|
if (!(heatmap.shape().size() == 4 && heatmap.data_type() == DataType::kFLOAT)) {
|
|
MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", heatmap.shape(),
|
|
(int)heatmap.data_type());
|
|
return Status(eNotSupported);
|
|
}
|
|
|
|
auto& img_metas = _data["img_metas"];
|
|
|
|
vector<float> center;
|
|
vector<float> scale;
|
|
from_value(img_metas["center"], center);
|
|
from_value(img_metas["scale"], scale);
|
|
Tensor pred =
|
|
keypoints_from_heatmap(heatmap, center, scale, unbiased_decoding_, post_process_,
|
|
modulate_kernel_, valid_radius_factor_, use_udp_, target_type_);
|
|
|
|
return GetOutput(pred);
|
|
}
|
|
|
|
Value GetOutput(Tensor& pred) {
|
|
PoseDetectorOutput output;
|
|
int K = pred.shape(1);
|
|
float* data = pred.data<float>();
|
|
for (int i = 0; i < K; i++) {
|
|
float x = *(data + 0);
|
|
float y = *(data + 1);
|
|
float s = *(data + 2);
|
|
output.key_points.push_back({{x, y}, s});
|
|
data += 3;
|
|
}
|
|
return to_value(std::move(output));
|
|
}
|
|
|
|
Tensor keypoints_from_heatmap(const Tensor& _heatmap, const vector<float>& center,
|
|
const vector<float>& scale, bool unbiased_decoding,
|
|
const string& post_process, int modulate_kernel,
|
|
float valid_radius_factor, bool use_udp,
|
|
const string& target_type) {
|
|
Tensor heatmap(_heatmap.desc());
|
|
heatmap.CopyFrom(_heatmap, stream()).value();
|
|
stream().Wait().value();
|
|
|
|
int K = heatmap.shape(1);
|
|
int H = heatmap.shape(2);
|
|
int W = heatmap.shape(3);
|
|
|
|
if (post_process == "megvii") {
|
|
heatmap = gaussian_blur(heatmap, modulate_kernel);
|
|
}
|
|
|
|
Tensor pred;
|
|
|
|
if (use_udp) {
|
|
if (to_lower(target_type) == to_lower(string("GaussianHeatMap"))) {
|
|
pred = get_max_pred(heatmap);
|
|
post_dark_udp(pred, heatmap, modulate_kernel);
|
|
} else if (to_lower(target_type) == to_lower(string("CombinedTarget"))) {
|
|
// output channel = 3 * channel_cfg['num_output_channels']
|
|
assert(K % 3 == 0);
|
|
cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) {
|
|
for (int i = r.start; i < r.end; i++) {
|
|
int kt = (i % 3 == 0) ? 2 * modulate_kernel + 1 : modulate_kernel;
|
|
float* data = heatmap.data<float>() + i * H * W;
|
|
cv::Mat work = cv::Mat(H, W, CV_32FC(1), data);
|
|
cv::GaussianBlur(work, work, {kt, kt}, 0); // inplace
|
|
}
|
|
}});
|
|
float valid_radius = valid_radius_factor_ * H;
|
|
TensorDesc desc = {Device{"cpu"}, DataType::kFLOAT, {1, K / 3, H, W}};
|
|
Tensor offset_x(desc);
|
|
Tensor offset_y(desc);
|
|
Tensor heatmap_(desc);
|
|
{
|
|
// split heatmap
|
|
float* src = heatmap.data<float>();
|
|
float* dst0 = heatmap_.data<float>();
|
|
float* dst1 = offset_x.data<float>();
|
|
float* dst2 = offset_y.data<float>();
|
|
for (int i = 0; i < K / 3; i++) {
|
|
std::copy_n(src, H * W, dst0);
|
|
std::transform(src + H * W, src + 2 * H * W, dst1,
|
|
[=](float& x) { return x * valid_radius; });
|
|
std::transform(src + 2 * H * W, src + 3 * H * W, dst2,
|
|
[=](float& x) { return x * valid_radius; });
|
|
src += 3 * H * W;
|
|
dst0 += H * W;
|
|
dst1 += H * W;
|
|
dst2 += H * W;
|
|
}
|
|
}
|
|
pred = get_max_pred(heatmap_);
|
|
for (int i = 0; i < K / 3; i++) {
|
|
float* data = pred.data<float>() + i * 3;
|
|
int index = *(data + 0) + *(data + 1) * W + H * W * i;
|
|
float* offx = offset_x.data<float>() + index;
|
|
float* offy = offset_y.data<float>() + index;
|
|
*(data + 0) += *offx;
|
|
*(data + 1) += *offy;
|
|
}
|
|
}
|
|
} else {
|
|
pred = get_max_pred(heatmap);
|
|
if (post_process == "unbiased") {
|
|
heatmap = gaussian_blur(heatmap, modulate_kernel);
|
|
float* data = heatmap.data<float>();
|
|
std::for_each(data, data + K * H * W, [](float& v) {
|
|
double _v = std::max((double)v, 1e-10);
|
|
v = std::log(_v);
|
|
});
|
|
for (int i = 0; i < K; i++) {
|
|
taylor(heatmap, pred, i);
|
|
}
|
|
|
|
} else if (post_process != "null") {
|
|
for (int i = 0; i < K; i++) {
|
|
float* data = heatmap.data<float>() + i * W * H;
|
|
auto _data = [&](int y, int x) { return *(data + y * W + x); };
|
|
int px = *(pred.data<float>() + i * 3 + 0);
|
|
int py = *(pred.data<float>() + i * 3 + 1);
|
|
if (1 < px && px < W - 1 && 1 < py && py < H - 1) {
|
|
float v1 = _data(py, px + 1) - _data(py, px - 1);
|
|
float v2 = _data(py + 1, px) - _data(py - 1, px);
|
|
*(pred.data<float>() + i * 3 + 0) += (v1 > 0) ? 0.25 : ((v1 < 0) ? -0.25 : 0);
|
|
*(pred.data<float>() + i * 3 + 1) += (v2 > 0) ? 0.25 : ((v2 < 0) ? -0.25 : 0);
|
|
if (post_process_ == "megvii") {
|
|
*(pred.data<float>() + i * 3 + 0) += 0.5;
|
|
*(pred.data<float>() + i * 3 + 1) += 0.5;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
K = pred.shape(1); // changed if target_type is CombinedTarget
|
|
|
|
// Transform back to the image
|
|
for (int i = 0; i < K; i++) {
|
|
transform_pred(pred, i, center, scale, {W, H}, use_udp);
|
|
}
|
|
|
|
if (post_process_ == "megvii") {
|
|
for (int i = 0; i < K; i++) {
|
|
float* data = pred.data<float>() + i * 3 + 2;
|
|
*data = *data / 255.0 + 0.5;
|
|
}
|
|
}
|
|
|
|
return pred;
|
|
}
|
|
|
|
void post_dark_udp(Tensor& pred, Tensor& heatmap, int kernel) {
|
|
int K = heatmap.shape(1);
|
|
int H = heatmap.shape(2);
|
|
int W = heatmap.shape(3);
|
|
cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) {
|
|
for (int i = r.start; i < r.end; i++) {
|
|
float* data = heatmap.data<float>() + i * H * W;
|
|
cv::Mat work = cv::Mat(H, W, CV_32FC(1), data);
|
|
cv::GaussianBlur(work, work, {kernel, kernel}, 0); // inplace
|
|
}
|
|
}});
|
|
std::for_each(heatmap.data<float>(), heatmap.data<float>() + K * H * W, [](float& x) {
|
|
x = std::max(0.001f, std::min(50.f, x));
|
|
x = std::log(x);
|
|
});
|
|
auto _heatmap_data = [&](int index, int c) -> float {
|
|
int y = index / (W + 2);
|
|
int x = index % (W + 2);
|
|
y = std::max(0, y - 1);
|
|
x = std::max(0, x - 1);
|
|
return *(heatmap.data<float>() + c * H * W + y * W + x);
|
|
};
|
|
for (int i = 0; i < K; i++) {
|
|
float* data = pred.data<float>() + i * 3;
|
|
int index = *(data + 0) + 1 + (*(data + 1) + 1) * (W + 2);
|
|
float i_ = _heatmap_data(index, i);
|
|
float ix1 = _heatmap_data(index + 1, i);
|
|
float iy1 = _heatmap_data(index + W + 2, i);
|
|
float ix1y1 = _heatmap_data(index + W + 3, i);
|
|
float ix1_y1_ = _heatmap_data(index - W - 3, i);
|
|
float ix1_ = _heatmap_data(index - 1, i);
|
|
float iy1_ = _heatmap_data(index - 2 - W, i);
|
|
float dx = 0.5 * (ix1 - ix1_);
|
|
float dy = 0.5 * (iy1 - iy1_);
|
|
float dxx = ix1 - 2 * i_ + ix1_;
|
|
float dyy = iy1 - 2 * i_ + iy1_;
|
|
float dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_);
|
|
vector<float> _data0 = {dx, dy};
|
|
vector<float> _data1 = {dxx, dxy, dxy, dyy};
|
|
cv::Mat derivative = cv::Mat(2, 1, CV_32FC1, _data0.data());
|
|
cv::Mat hessian = cv::Mat(2, 2, CV_32FC1, _data1.data());
|
|
cv::Mat hessianinv = hessian.inv();
|
|
cv::Mat offset = -hessianinv * derivative;
|
|
*(data + 0) += offset.at<float>(0, 0);
|
|
*(data + 1) += offset.at<float>(1, 0);
|
|
}
|
|
}
|
|
|
|
void transform_pred(Tensor& pred, int k, const vector<float>& center, const vector<float>& _scale,
|
|
const vector<int>& output_size, bool use_udp = false) {
|
|
auto scale = _scale;
|
|
scale[0] *= 200;
|
|
scale[1] *= 200;
|
|
|
|
float scale_x, scale_y;
|
|
if (use_udp) {
|
|
scale_x = scale[0] / (output_size[0] - 1.0);
|
|
scale_y = scale[1] / (output_size[1] - 1.0);
|
|
} else {
|
|
scale_x = scale[0] / output_size[0];
|
|
scale_y = scale[1] / output_size[1];
|
|
}
|
|
|
|
float* data = pred.data<float>() + k * 3;
|
|
*(data + 0) = *(data + 0) * scale_x + center[0] - scale[0] * 0.5;
|
|
*(data + 1) = *(data + 1) * scale_y + center[1] - scale[1] * 0.5;
|
|
}
|
|
|
|
void taylor(const Tensor& heatmap, Tensor& pred, int k) {
|
|
int K = heatmap.shape(1);
|
|
int H = heatmap.shape(2);
|
|
int W = heatmap.shape(3);
|
|
int px = *(pred.data<float>() + k * 3 + 0);
|
|
int py = *(pred.data<float>() + k * 3 + 1);
|
|
if (1 < px && px < W - 2 && 1 < py && py < H - 2) {
|
|
float* data = const_cast<float*>(heatmap.data<float>() + k * H * W);
|
|
auto get_data = [&](int r, int c) { return *(data + r * W + c); };
|
|
float dx = 0.5 * (get_data(py, px + 1) - get_data(py, px - 1));
|
|
float dy = 0.5 * (get_data(py + 1, px) - get_data(py - 1, px));
|
|
float dxx = 0.25 * (get_data(py, px + 2) - 2 * get_data(py, px) + get_data(py, px - 2));
|
|
float dxy = 0.25 * (get_data(py + 1, px + 1) - get_data(py - 1, px + 1) -
|
|
get_data(py + 1, px - 1) + get_data(py - 1, px - 1));
|
|
float dyy = 0.25 * (get_data(py + 2, px) - 2 * get_data(py, px) + get_data(py - 2, px));
|
|
|
|
vector<float> _data0 = {dx, dy};
|
|
vector<float> _data1 = {dxx, dxy, dxy, dyy};
|
|
cv::Mat derivative = cv::Mat(2, 1, CV_32FC1, _data0.data());
|
|
cv::Mat hessian = cv::Mat(2, 2, CV_32FC1, _data1.data());
|
|
if (std::fabs(dxx * dyy - dxy * dxy) > 1e-6) {
|
|
cv::Mat hessianinv = hessian.inv();
|
|
cv::Mat offset = -hessianinv * derivative;
|
|
*(pred.data<float>() + k * 3 + 0) += offset.at<float>(0, 0);
|
|
*(pred.data<float>() + k * 3 + 1) += offset.at<float>(1, 0);
|
|
}
|
|
}
|
|
}
|
|
|
|
Tensor gaussian_blur(const Tensor& _heatmap, int kernel) {
|
|
assert(kernel % 2 == 1);
|
|
|
|
auto desc = _heatmap.desc();
|
|
Tensor heatmap(desc);
|
|
|
|
int K = _heatmap.shape(1);
|
|
int H = _heatmap.shape(2);
|
|
int W = _heatmap.shape(3);
|
|
int num_points = H * W;
|
|
|
|
int border = (kernel - 1) / 2;
|
|
|
|
for (int i = 0; i < K; i++) {
|
|
int offset = i * H * W;
|
|
float* data = const_cast<float*>(_heatmap.data<float>()) + offset;
|
|
float origin_max = *std::max_element(data, data + num_points);
|
|
cv::Mat work = cv::Mat(H + 2 * border, W + 2 * border, CV_32FC1, cv::Scalar{});
|
|
cv::Mat curr = cv::Mat(H, W, CV_32FC1, data);
|
|
cv::Rect roi = {border, border, W, H};
|
|
curr.copyTo(work(roi));
|
|
cv::GaussianBlur(work, work, {kernel, kernel}, 0);
|
|
cv::Mat valid = work(roi).clone();
|
|
float cur_max = *std::max_element((float*)valid.data, (float*)valid.data + num_points);
|
|
float* dst = heatmap.data<float>() + offset;
|
|
std::transform((float*)valid.data, (float*)valid.data + num_points, dst,
|
|
[&](float v) { return v * origin_max / cur_max; });
|
|
}
|
|
return heatmap;
|
|
}
|
|
|
|
Tensor get_max_pred(const Tensor& heatmap) {
|
|
int K = heatmap.shape(1);
|
|
int H = heatmap.shape(2);
|
|
int W = heatmap.shape(3);
|
|
int num_points = H * W;
|
|
TensorDesc pred_desc = {Device{"cpu"}, DataType::kFLOAT, {1, K, 3}};
|
|
Tensor pred(pred_desc);
|
|
|
|
cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) {
|
|
for (int i = r.start; i < r.end; i++) {
|
|
float* src_data = const_cast<float*>(heatmap.data<float>()) + i * H * W;
|
|
cv::Mat mat = cv::Mat(H, W, CV_32FC1, src_data);
|
|
double min_val, max_val;
|
|
cv::Point min_loc, max_loc;
|
|
cv::minMaxLoc(mat, &min_val, &max_val, &min_loc, &max_loc);
|
|
float* dst_data = pred.data<float>() + i * 3;
|
|
*(dst_data + 0) = -1;
|
|
*(dst_data + 1) = -1;
|
|
*(dst_data + 2) = max_val;
|
|
if (max_val > 0.0) {
|
|
*(dst_data + 0) = max_loc.x;
|
|
*(dst_data + 1) = max_loc.y;
|
|
}
|
|
}
|
|
}});
|
|
|
|
return pred;
|
|
}
|
|
|
|
private:
|
|
bool flip_test_{true};
|
|
bool shift_heatmap_{true};
|
|
string post_process_ = {"default"};
|
|
int modulate_kernel_{11};
|
|
bool unbiased_decoding_{false};
|
|
float valid_radius_factor_{0.0546875f};
|
|
bool use_udp_{false};
|
|
string target_type_{"GaussianHeatmap"};
|
|
};
|
|
|
|
REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapBaseHeadDecode);
|
|
|
|
// decode process is same
|
|
using TopdownHeatmapSimpleHeadDecode = TopdownHeatmapBaseHeadDecode;
|
|
REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapSimpleHeadDecode);
|
|
using TopdownHeatmapMultiStageHeadDecode = TopdownHeatmapBaseHeadDecode;
|
|
REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapMultiStageHeadDecode);
|
|
using ViPNASHeatmapSimpleHeadDecode = TopdownHeatmapBaseHeadDecode;
|
|
REGISTER_CODEBASE_COMPONENT(MMPose, ViPNASHeatmapSimpleHeadDecode);
|
|
using TopdownHeatmapMSMUHeadDecode = TopdownHeatmapBaseHeadDecode;
|
|
REGISTER_CODEBASE_COMPONENT(MMPose, TopdownHeatmapMSMUHeadDecode);
|
|
|
|
} // namespace mmdeploy::mmpose
|