2021-12-07 10:57:55 +08:00
|
|
|
// Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
|
2022-05-31 21:24:09 +08:00
|
|
|
#include "codebase/mmocr/dbnet.h"
|
|
|
|
|
2021-12-07 10:57:55 +08:00
|
|
|
#include <opencv2/imgproc.hpp>
|
|
|
|
|
|
|
|
#include "clipper.hpp"
|
|
|
|
#include "core/device.h"
|
|
|
|
#include "core/tensor.h"
|
|
|
|
#include "core/utils/formatter.h"
|
|
|
|
#include "experimental/module_adapter.h"
|
|
|
|
#include "mmocr.h"
|
|
|
|
|
2022-05-31 21:24:09 +08:00
|
|
|
namespace mmdeploy {
|
|
|
|
|
|
|
|
namespace mmocr {
|
2021-12-07 10:57:55 +08:00
|
|
|
|
|
|
|
using std::string;
|
|
|
|
using std::vector;
|
|
|
|
|
2021-12-16 13:51:22 +08:00
|
|
|
class DBHead : public MMOCR {
|
2021-12-07 10:57:55 +08:00
|
|
|
public:
|
2021-12-16 13:51:22 +08:00
|
|
|
explicit DBHead(const Value& config) : MMOCR(config) {
|
2021-12-07 10:57:55 +08:00
|
|
|
if (config.contains("params")) {
|
|
|
|
auto& params = config["params"];
|
2022-05-31 21:24:09 +08:00
|
|
|
text_repr_type_ = params.value("text_repr_type", text_repr_type_);
|
|
|
|
mask_thr_ = params.value("mask_thr", mask_thr_);
|
|
|
|
min_text_score_ = params.value("min_text_score", min_text_score_);
|
|
|
|
min_text_width_ = params.value("min_text_width", min_text_width_);
|
|
|
|
unclip_ratio_ = params.value("unclip_ratio", unclip_ratio_);
|
|
|
|
max_candidates_ = params.value("max_candidate", max_candidates_);
|
|
|
|
rescale_ = params.value("rescale", rescale_);
|
|
|
|
downsample_ratio_ = params.value("downsample_ratio", downsample_ratio_);
|
|
|
|
}
|
|
|
|
auto platform = Platform(device_.platform_id()).GetPlatformName();
|
|
|
|
auto creator = Registry<DbHeadImpl>::Get().GetCreator(platform);
|
|
|
|
if (!creator) {
|
|
|
|
MMDEPLOY_ERROR("DBHead: implementation for platform \"{}\" not found", platform);
|
|
|
|
throw_exception(eEntryNotFound);
|
2021-12-07 10:57:55 +08:00
|
|
|
}
|
2022-05-31 21:24:09 +08:00
|
|
|
impl_ = creator->Create(nullptr);
|
|
|
|
impl_->Init(stream_);
|
2021-12-07 10:57:55 +08:00
|
|
|
}
|
|
|
|
|
2022-05-31 21:24:09 +08:00
|
|
|
Result<Value> operator()(const Value& _data, const Value& _prob) const {
|
|
|
|
auto conf = _prob["output"].get<Tensor>();
|
2021-12-21 20:16:40 +08:00
|
|
|
if (!(conf.shape().size() == 4 && conf.data_type() == DataType::kFLOAT)) {
|
2022-02-24 20:08:44 +08:00
|
|
|
MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", conf.shape(),
|
|
|
|
(int)conf.data_type());
|
2021-12-21 20:16:40 +08:00
|
|
|
return Status(eNotSupported);
|
|
|
|
}
|
|
|
|
|
2022-05-31 21:24:09 +08:00
|
|
|
conf.Squeeze();
|
|
|
|
conf = conf.Slice(0);
|
2021-12-07 10:57:55 +08:00
|
|
|
|
|
|
|
std::vector<std::vector<cv::Point>> contours;
|
2022-05-31 21:24:09 +08:00
|
|
|
std::vector<float> scores;
|
|
|
|
OUTCOME_TRY(impl_->Process(conf, mask_thr_, max_candidates_, contours, scores));
|
|
|
|
|
|
|
|
auto scale_w = 1.f;
|
|
|
|
auto scale_h = 1.f;
|
|
|
|
if (rescale_) {
|
|
|
|
scale_w /= downsample_ratio_ * _data["img_metas"]["scale_factor"][0].get<float>();
|
|
|
|
scale_h /= downsample_ratio_ * _data["img_metas"]["scale_factor"][1].get<float>();
|
2021-12-07 10:57:55 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
TextDetectorOutput output;
|
2022-05-31 21:24:09 +08:00
|
|
|
for (int idx = 0; idx < contours.size(); ++idx) {
|
|
|
|
if (scores[idx] < min_text_score_) {
|
2021-12-07 10:57:55 +08:00
|
|
|
continue;
|
|
|
|
}
|
2022-05-31 21:24:09 +08:00
|
|
|
auto expanded = unclip(contours[idx], unclip_ratio_);
|
|
|
|
if (expanded.empty()) {
|
2021-12-07 10:57:55 +08:00
|
|
|
continue;
|
|
|
|
}
|
2022-05-31 21:24:09 +08:00
|
|
|
auto rect = cv::minAreaRect(expanded);
|
|
|
|
if ((int)rect.size.width <= min_text_width_) {
|
2021-12-07 10:57:55 +08:00
|
|
|
continue;
|
|
|
|
}
|
2022-05-31 21:24:09 +08:00
|
|
|
std::array<cv::Point2f, 4> box_points;
|
|
|
|
rect.points(box_points.data());
|
2021-12-07 10:57:55 +08:00
|
|
|
auto& bbox = output.boxes.emplace_back();
|
|
|
|
for (int i = 0; i < 4; ++i) {
|
2022-05-31 21:24:09 +08:00
|
|
|
// ! performance metrics drops without rounding here
|
|
|
|
bbox[i * 2] = cvRound(box_points[i].x * scale_w);
|
|
|
|
bbox[i * 2 + 1] = cvRound(box_points[i].y * scale_h);
|
2021-12-07 10:57:55 +08:00
|
|
|
}
|
2022-05-31 21:24:09 +08:00
|
|
|
output.scores.push_back(scores[idx]);
|
2021-12-07 10:57:55 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
return to_value(output);
|
|
|
|
}
|
|
|
|
|
|
|
|
static std::vector<cv::Point> unclip(std::vector<cv::Point>& box, float unclip_ratio) {
|
|
|
|
namespace cl = ClipperLib;
|
|
|
|
|
|
|
|
auto area = cv::contourArea(box);
|
|
|
|
auto length = cv::arcLength(box, true);
|
|
|
|
auto distance = area * unclip_ratio / length;
|
|
|
|
|
|
|
|
cl::Path src;
|
|
|
|
transform(begin(box), end(box), back_inserter(src), [](auto p) {
|
|
|
|
return cl::IntPoint{p.x, p.y};
|
|
|
|
});
|
|
|
|
|
|
|
|
cl::ClipperOffset offset;
|
|
|
|
offset.AddPath(src, cl::jtRound, cl::etClosedPolygon);
|
|
|
|
|
|
|
|
std::vector<cl::Path> dst;
|
|
|
|
offset.Execute(dst, distance);
|
|
|
|
if (dst.size() != 1) {
|
|
|
|
return {};
|
|
|
|
}
|
|
|
|
|
|
|
|
std::vector<cv::Point> ret;
|
|
|
|
transform(begin(dst[0]), end(dst[0]), back_inserter(ret), [](auto p) {
|
|
|
|
return cv::Point{static_cast<int>(p.X), static_cast<int>(p.Y)};
|
|
|
|
});
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::string text_repr_type_{"quad"};
|
|
|
|
float mask_thr_{.3};
|
|
|
|
float min_text_score_{.3};
|
|
|
|
int min_text_width_{5};
|
|
|
|
float unclip_ratio_{1.5};
|
|
|
|
int max_candidates_{3000};
|
|
|
|
bool rescale_{true};
|
|
|
|
float downsample_ratio_{1.};
|
2022-05-31 21:24:09 +08:00
|
|
|
|
|
|
|
std::unique_ptr<DbHeadImpl> impl_;
|
2021-12-07 10:57:55 +08:00
|
|
|
};
|
|
|
|
|
2021-12-16 13:51:22 +08:00
|
|
|
// Register DBHead as a component of the MMOCR codebase. The macro is defined outside
// this file; presumably it makes the component creatable by name -- confirm.
REGISTER_CODEBASE_COMPONENT(MMOCR, DBHead);
|
2021-12-07 10:57:55 +08:00
|
|
|
|
2022-05-31 21:24:09 +08:00
|
|
|
} // namespace mmocr
|
|
|
|
|
|
|
|
// Define the registry for mmocr::DbHeadImpl; the DBHead constructor queries it through
// Registry<DbHeadImpl>::Get() to find the creator for the current platform.
MMDEPLOY_DEFINE_REGISTRY(mmocr::DbHeadImpl);
|
|
|
|
|
|
|
|
} // namespace mmdeploy
|