mmdeploy/csrc/codebase/mmocr/contour_expand.cpp
lzhangzz 5e51739485
[Feature] Support DBNet, PANet and PSENet for SDK, with GPU aided post-processing (#526)
* add SDK support for PANet

* fix panet

* fix panet

* simplify panet

* add PSENet support

* fix-psenet

* add CUDA post-processing for DBNet

* fix dbnet

* fix dbnet

* add cpu support for PANet

* fix panet

* add CUDA support for PANet

* fix panet

* format

* add cpu impls for PSENet

* fix psenet

* add cuda impl for PSENet

* fix psenet

* add param parsing

* simplify impls

* simplify impls

* clean-up

* fix lint

* fix cuda-10 build

* fix cuda-10 build
2022-05-31 21:24:09 +08:00

125 lines
3.9 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved
// Modified from https://github.com/whai362/PSENet
// and
// https://github.com/open-mmlab/mmcv/blob/master/mmcv/ops/csrc/pytorch/contour_expand.cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <queue>
#include <vector>

#include "core/tensor.h"
#include "opencv2/opencv.hpp"
namespace mmdeploy::mmocr {
using namespace std;
using cv::Mat_;
// Integer pixel coordinate. kernel_dilate() in this file stores the row index
// in `x` and the column index in `y`.
struct Point2d {
  int x{0};
  int y{0};

  // Both coordinates default to zero.
  Point2d() = default;

  // Builds a point from explicit coordinates.
  Point2d(int _x, int _y) : x{_x}, y{_y} {}
};
void kernel_dilate(const uint8_t* data, int kernel_num, int height, int width, const int* label_map,
int label_num, const float* score_map, int min_area, Mat_<int32_t>& text_labels,
vector<int>& text_areas, vector<float>& text_scores,
vector<vector<int>>& text_points) {
text_labels = Mat_<int32_t>::zeros(height, width);
text_areas = vector<int>(label_num);
text_scores = vector<float>(label_num);
text_points = vector<vector<int>>(label_num);
for (int x = 0; x < height; ++x) {
for (int y = 0; y < width; ++y) {
int label = label_map[x * width + y];
if (label == 0) continue;
text_areas[label] += 1;
text_scores[label] += score_map[x * width + y];
text_points[label].push_back(y);
text_points[label].push_back(x);
}
}
queue<Point2d> queue, next_queue;
for (int x = 0; x < height; ++x) {
auto row = text_labels[x];
for (int y = 0; y < width; ++y) {
int label = label_map[x * width + y];
if (label == 0) continue;
if (text_areas[label] < min_area) continue;
Point2d point(x, y);
queue.push(point);
row[y] = label;
}
}
const int dx[] = {-1, 1, 0, 0};
const int dy[] = {0, 0, -1, 1};
vector<int> kernel_step(kernel_num);
std::for_each(kernel_step.begin(), kernel_step.end(), [=](int& k) { return k * height * width; });
for (int kernel_id = kernel_num - 2; kernel_id >= 0; --kernel_id) {
while (!queue.empty()) {
Point2d point = queue.front();
queue.pop();
int x = point.x;
int y = point.y;
int label = text_labels[x][y];
bool is_edge = true;
for (int d = 0; d < 4; ++d) {
int tmp_x = x + dx[d];
int tmp_y = y + dy[d];
if (tmp_x < 0 || tmp_x >= height) continue;
if (tmp_y < 0 || tmp_y >= width) continue;
int kernel_value = data[kernel_step[kernel_id] + tmp_x * width + tmp_y];
if (kernel_value == 0) continue;
if (text_labels[tmp_x][tmp_y] > 0) continue;
Point2d point(tmp_x, tmp_y);
queue.push(point);
text_labels[tmp_x][tmp_y] = label;
text_areas[label] += 1;
text_scores[label] += score_map[tmp_x * width + tmp_y];
text_points[label].push_back(tmp_y);
text_points[label].push_back(tmp_x);
is_edge = false;
}
if (is_edge) {
next_queue.push(point);
}
}
swap(queue, next_queue);
}
for (int i = 1; i < label_num; ++i) {
if (text_areas[i]) {
text_scores[i] /= static_cast<float>(text_areas[i]);
}
}
}
/**
 * Entry point of the contour expansion: unwraps the OpenCV matrices into raw
 * row-major buffers and runs kernel_dilate() on them.
 *
 * @param kernel_masks    One mask per row; every row holds the
 *                        `kernel_label.total()` flattened pixels of one mask.
 * @param kernel_label    Seed label image (0 = no label).
 * @param score           Per-pixel scores, same size as `kernel_label`.
 * @param min_kernel_area Minimum seed area, forwarded as `min_area`.
 * @param kernel_num      NOTE: despite its name this value is forwarded to
 *                        kernel_dilate() as `label_num` (the size of the
 *                        per-label outputs). The number of masks is taken from
 *                        `kernel_masks.rows`.
 * @param text_areas      [out] Pixel count per label.
 * @param text_scores     [out] Mean score per label.
 * @param text_points     [out] Interleaved `(column, row)` pixel coordinates
 *                        per label.
 */
void contour_expand(const Mat_<uint8_t>& kernel_masks, const Mat_<int32_t>& kernel_label,
                    const Mat_<float>& score, int min_kernel_area, int kernel_num,
                    vector<int>& text_areas, vector<float>& text_scores,
                    vector<vector<int>>& text_points) {
  assert(kernel_masks.cols == static_cast<int>(kernel_label.total()));
  assert(score.size() == kernel_label.size());
  // The buffers below are indexed as flat row-major arrays, which is only
  // valid when the matrices have no padding between rows.
  assert(kernel_masks.isContinuous());
  assert(kernel_label.isContinuous());
  assert(score.isContinuous());

  const uint8_t* mask_data = kernel_masks.ptr<uint8_t>();
  const float* score_data = score.ptr<float>();
  const int32_t* label_data = kernel_label.ptr<int32_t>();

  // The expanded label image is computed but not exposed to the caller.
  Mat_<int32_t> text_labels;
  kernel_dilate(mask_data, kernel_masks.rows, kernel_label.rows, kernel_label.cols, label_data,
                kernel_num, score_data, min_kernel_area, text_labels, text_areas, text_scores,
                text_points);
}
} // namespace mmdeploy::mmocr