// Copyright (c) OpenMMLab. All rights reserved. #include "codebase/mmocr/panet.h" #include "codebase/mmocr/cuda/connected_component.h" #include "codebase/mmocr/cuda/utils.h" #include "device/cuda/cuda_device.h" namespace mmdeploy::mmocr { class PaHeadCudaImpl : public PaHeadImpl { public: void Init(const Stream& stream) override { PaHeadImpl::Init(stream); device_ = stream.GetDevice(); { CudaDeviceGuard device_guard(device_); cc_.emplace(GetNative(stream_)); } } ~PaHeadCudaImpl() override { CudaDeviceGuard device_guard(device_); cc_.reset(); } Result Process(Tensor text_pred, // Tensor kernel_pred, // Tensor embed_pred, // float min_text_confidence, // float min_kernel_confidence, // cv::Mat_& text_score, // cv::Mat_& text, // cv::Mat_& kernel, // cv::Mat_& label, // cv::Mat_& embed, // int& region_num) override { CudaDeviceGuard device_guard(device_); auto height = static_cast(text_pred.shape(1)); auto width = static_cast(text_pred.shape(2)); Buffer text_buf(device_, height * width); Buffer text_score_buf(device_, height * width * sizeof(float)); Buffer kernel_buf(device_, height * width); auto text_buf_data = GetNative(text_buf); auto text_score_buf_data = GetNative(text_score_buf); auto kernel_buf_data = GetNative(kernel_buf); auto stream = GetNative(stream_); panet::ProcessMasks(text_pred.data(), // kernel_pred.data(), // min_text_confidence, // min_kernel_confidence, // height * width, // text_buf_data, // kernel_buf_data, // text_score_buf_data, // stream); auto n_embed_channels = embed_pred.shape(0); Buffer embed_buf(device_, embed_pred.byte_size()); panet::Transpose(embed_pred.data(), // n_embed_channels, // height * width, // GetNative(embed_buf), // stream); label = cv::Mat_(height, width); cc_->Resize(height, width); region_num = cc_->GetComponents(kernel_buf_data, label.ptr()) + 1; text_score = cv::Mat_(label.size()); OUTCOME_TRY(stream_.Copy(text_score_buf, text_score.data)); text = cv::Mat_(label.size()); OUTCOME_TRY(stream_.Copy(text_buf, text.data)); kernel = cv::Mat_(label.size()); OUTCOME_TRY(stream_.Copy(kernel_buf, kernel.data)); embed = cv::Mat_(height * width, n_embed_channels); OUTCOME_TRY(stream_.Copy(embed_buf, embed.data)); OUTCOME_TRY(stream_.Wait()); return success(); } private: Device device_; std::optional cc_; }; class PaHeadCudaImplCreator : public ::mmdeploy::Creator { public: const char* GetName() const override { return "cuda"; } int GetVersion() const override { return 0; } std::unique_ptr Create(const Value&) override { return std::make_unique(); } }; REGISTER_MODULE(PaHeadImpl, PaHeadCudaImplCreator); } // namespace mmdeploy::mmocr