// Copyright (c) OpenMMLab. All rights reserved. #include "codebase/mmocr/dbnet.h" #include "codebase/mmocr/cuda/connected_component.h" #include "codebase/mmocr/cuda/utils.h" #include "core/utils/device_utils.h" #include "cuda_runtime.h" #include "device/cuda/cuda_device.h" #include "opencv2/imgcodecs.hpp" #include "opencv2/imgproc.hpp" namespace mmdeploy::mmocr { class DbHeadCudaImpl : public DbHeadImpl { public: void Init(const Stream& stream) override { DbHeadImpl::Init(stream); device_ = stream_.GetDevice(); { CudaDeviceGuard device_guard(device_); cc_.emplace(GetNative<cudaStream_t>(stream_)); } } ~DbHeadCudaImpl() override { CudaDeviceGuard device_guard(device_); cc_.reset(); } Result<void> Process(Tensor score, float mask_thr, int max_candidates, std::vector<std::vector<cv::Point>>& contours, std::vector<float>& scores) override { CudaDeviceGuard device_guard(device_); auto height = static_cast<int>(score.shape(1)); auto width = static_cast<int>(score.shape(2)); Buffer mask(device_, score.size() * sizeof(uint8_t)); auto score_data = score.data<float>(); auto mask_data = GetNative<uint8_t*>(mask); dbnet::Threshold(score_data, height * width, mask_thr, mask_data, GetNative<cudaStream_t>(stream_)); cc_->Resize(height, width); cc_->GetComponents(mask_data, nullptr); std::vector<std::vector<cv::Point>> points; cc_->GetContours(points); std::vector<float> _scores; std::vector<int> _areas; cc_->GetStats(mask_data, score_data, _scores, _areas); if (points.size() > max_candidates) { points.resize(max_candidates); } for (int i = 0; i < points.size(); ++i) { std::vector<cv::Point> hull; cv::convexHull(points[i], hull); if (hull.size() < 4) { continue; } contours.push_back(hull); scores.push_back(_scores[i] / (float)_areas[i]); } return success(); } private: Device device_; std::optional<ConnectedComponents> cc_; }; class DbHeadCudaImplCreator : public ::mmdeploy::Creator<DbHeadImpl> { public: const char* GetName() const override { return "cuda"; } int GetVersion() const override { return 0; } std::unique_ptr<DbHeadImpl> Create(const Value&) override { return std::make_unique<DbHeadCudaImpl>(); } }; REGISTER_MODULE(DbHeadImpl, DbHeadCudaImplCreator); } // namespace mmdeploy::mmocr