optimize mmpose postprocess (#1887)

pull/1864/merge
Chen Xin 2023-03-21 11:06:18 +08:00 committed by GitHub
parent 34c68663b6
commit 06dac732c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 27 additions and 44 deletions

View File

@ -18,13 +18,6 @@ namespace mmdeploy::mmpose {
using std::string;
using std::vector;
// Adapts an arbitrary callable to the cv::ParallelLoopBody interface so that it
// can be handed to cv::parallel_for_.
template <typename Fn>
struct _LoopBody : public cv::ParallelLoopBody {
  // The wrapped callable. It is invoked from a const member function, so it
  // must be callable through a const object.
  Fn f_;

  // Takes the callable by value and moves it into the wrapper.
  _LoopBody(Fn fn) : f_(std::move(fn)) {}

  // Forwards the sub-range received from the caller to the wrapped callable.
  void operator()(const cv::Range& sub_range) const override {
    f_(sub_range);
  }
};
std::string to_lower(const std::string& s) {
std::string t = s;
std::transform(t.begin(), t.end(), t.begin(), [](unsigned char c) { return std::tolower(c); });
@ -88,15 +81,11 @@ class TopdownHeatmapBaseHeadDecode : public MMPose {
return to_value(std::move(output));
}
Tensor keypoints_from_heatmap(const Tensor& _heatmap, const vector<float>& center,
Tensor keypoints_from_heatmap(Tensor& heatmap, const vector<float>& center,
const vector<float>& scale, bool unbiased_decoding,
const string& post_process, int modulate_kernel,
float valid_radius_factor, bool use_udp,
const string& target_type) {
Tensor heatmap(_heatmap.desc());
heatmap.CopyFrom(_heatmap, stream()).value();
stream().Wait().value();
int K = heatmap.shape(1);
int H = heatmap.shape(2);
int W = heatmap.shape(3);
@ -114,14 +103,12 @@ class TopdownHeatmapBaseHeadDecode : public MMPose {
} else if (to_lower(target_type) == to_lower(string("CombinedTarget"))) {
// output channel = 3 * channel_cfg['num_output_channels']
assert(K % 3 == 0);
cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) {
for (int i = r.start; i < r.end; i++) {
int kt = (i % 3 == 0) ? 2 * modulate_kernel + 1 : modulate_kernel;
float* data = heatmap.data<float>() + i * H * W;
cv::Mat work = cv::Mat(H, W, CV_32FC(1), data);
cv::GaussianBlur(work, work, {kt, kt}, 0); // inplace
}
}});
for (int i = 0; i < K; i++) {
int kt = (i % 3 == 0) ? 2 * modulate_kernel + 1 : modulate_kernel;
float* data = heatmap.data<float>() + i * H * W;
cv::Mat work = cv::Mat(H, W, CV_32FC(1), data);
cv::GaussianBlur(work, work, {kt, kt}, 0); // inplace
}
float valid_radius = valid_radius_factor_ * H;
TensorDesc desc = {Device{"cpu"}, DataType::kFLOAT, {1, K / 3, H, W}};
Tensor offset_x(desc);
@ -209,13 +196,11 @@ class TopdownHeatmapBaseHeadDecode : public MMPose {
int K = heatmap.shape(1);
int H = heatmap.shape(2);
int W = heatmap.shape(3);
cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) {
for (int i = r.start; i < r.end; i++) {
float* data = heatmap.data<float>() + i * H * W;
cv::Mat work = cv::Mat(H, W, CV_32FC(1), data);
cv::GaussianBlur(work, work, {kernel, kernel}, 0); // inplace
}
}});
for (int i = 0; i < K; i++) {
float* data = heatmap.data<float>() + i * H * W;
cv::Mat work = cv::Mat(H, W, CV_32FC(1), data);
cv::GaussianBlur(work, work, {kernel, kernel}, 0); // inplace
}
std::for_each(heatmap.data<float>(), heatmap.data<float>() + K * H * W, [](float& x) {
x = std::max(0.001f, std::min(50.f, x));
x = std::log(x);
@ -341,23 +326,21 @@ class TopdownHeatmapBaseHeadDecode : public MMPose {
TensorDesc pred_desc = {Device{"cpu"}, DataType::kFLOAT, {1, K, 3}};
Tensor pred(pred_desc);
cv::parallel_for_(cv::Range(0, K), _LoopBody{[&](const cv::Range& r) {
for (int i = r.start; i < r.end; i++) {
float* src_data = const_cast<float*>(heatmap.data<float>()) + i * H * W;
cv::Mat mat = cv::Mat(H, W, CV_32FC1, src_data);
double min_val, max_val;
cv::Point min_loc, max_loc;
cv::minMaxLoc(mat, &min_val, &max_val, &min_loc, &max_loc);
float* dst_data = pred.data<float>() + i * 3;
*(dst_data + 0) = -1;
*(dst_data + 1) = -1;
*(dst_data + 2) = max_val;
if (max_val > 0.0) {
*(dst_data + 0) = max_loc.x;
*(dst_data + 1) = max_loc.y;
}
}
}});
for (int i = 0; i < K; i++) {
float* src_data = const_cast<float*>(heatmap.data<float>()) + i * H * W;
cv::Mat mat = cv::Mat(H, W, CV_32FC1, src_data);
double min_val, max_val;
cv::Point min_loc, max_loc;
cv::minMaxLoc(mat, &min_val, &max_val, &min_loc, &max_loc);
float* dst_data = pred.data<float>() + i * 3;
*(dst_data + 0) = -1;
*(dst_data + 1) = -1;
*(dst_data + 2) = max_val;
if (max_val > 0.0) {
*(dst_data + 0) = max_loc.x;
*(dst_data + 1) = max_loc.y;
}
}
return pred;
}