// Copyright (c) OpenMMLab. All rights reserved.

#include "core/utils/device_utils.h"
#include "core/utils/formatter.h"
#include "ppl/cv/cuda/resize.h"
#include "preprocess/transform/resize.h"

using namespace std;

namespace mmdeploy {
namespace cuda {

class ResizeImpl final : public ::mmdeploy::ResizeImpl {
 public:
  explicit ResizeImpl(const Value& args) : ::mmdeploy::ResizeImpl(args) {
    if (arg_.interpolation != "bilinear" && arg_.interpolation != "nearest") {
      ERROR("{} interpolation is not supported", arg_.interpolation);
      throw_exception(eNotSupported);
    }
  }
  ~ResizeImpl() override = default;

 protected:
  Result<Tensor> ResizeImage(const Tensor& tensor, int dst_h, int dst_w) override {
    OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

    TensorDesc dst_desc{
        device_, src_tensor.data_type(), {1, dst_h, dst_w, src_tensor.shape(3)}, src_tensor.name()};
    Tensor dst_tensor(dst_desc);

    auto stream = GetNative<cudaStream_t>(stream_);
    if (tensor.data_type() == DataType::kINT8) {
      OUTCOME_TRY(ResizeDispatch<uint8_t>(src_tensor, dst_tensor, stream));
    } else if (tensor.data_type() == DataType::kFLOAT) {
      OUTCOME_TRY(ResizeDispatch<float>(src_tensor, dst_tensor, stream));
    } else {
      ERROR("unsupported data type {}", tensor.data_type());
      return Status(eNotSupported);
    }
    return dst_tensor;
  }

 private:
  // Select the ppl.cv resize kernel that matches the configured interpolation mode.
  // ppl.cv >= 0.6 exposes a single Resize entry point taking an interpolation enum;
  // older versions provide ResizeLinear / ResizeNearestPoint instead.
  template <class T, int C, class... Args>
  ppl::common::RetCode DispatchImpl(Args&&... args) {
#ifdef PPLCV_VERSION_MAJOR
    if (arg_.interpolation == "bilinear") {
      return ppl::cv::cuda::Resize<T, C>(std::forward<Args>(args)...,
                                         ppl::cv::INTERPOLATION_TYPE_LINEAR);
    }
    if (arg_.interpolation == "nearest") {
      return ppl::cv::cuda::Resize<T, C>(std::forward<Args>(args)...,
                                         ppl::cv::INTERPOLATION_TYPE_NEAREST_POINT);
    }
#else
#warning "support for ppl.cv < 0.6 is deprecated and will be dropped in the future"
    if (arg_.interpolation == "bilinear") {
      return ppl::cv::cuda::ResizeLinear<T, C>(std::forward<Args>(args)...);
    }
    if (arg_.interpolation == "nearest") {
      return ppl::cv::cuda::ResizeNearestPoint<T, C>(std::forward<Args>(args)...);
    }
#endif
    return ppl::common::RC_UNSUPPORTED;
  }

  // Dispatch on the channel count of the NHWC tensor (1, 3 or 4 channels are supported).
  template <class T>
  Result<void> ResizeDispatch(const Tensor& src, Tensor& dst, cudaStream_t stream) {
    int h = (int)src.shape(1);
    int w = (int)src.shape(2);
    int c = (int)src.shape(3);
    int dst_h = (int)dst.shape(1);
    int dst_w = (int)dst.shape(2);
    ppl::common::RetCode ret = 0;

    auto input = src.data<T>();
    auto output = dst.data<T>();
    if (1 == c) {
      ret = DispatchImpl<T, 1>(stream, h, w, w * c, input, dst_h, dst_w, dst_w * c, output);
    } else if (3 == c) {
      ret = DispatchImpl<T, 3>(stream, h, w, w * c, input, dst_h, dst_w, dst_w * c, output);
    } else if (4 == c) {
      ret = DispatchImpl<T, 4>(stream, h, w, w * c, input, dst_h, dst_w, dst_w * c, output);
    } else {
      ERROR("unsupported channels {}", c);
      return Status(eNotSupported);
    }
    return ret == 0 ? success() : Result<void>(Status(eFail));
  }
};

class ResizeImplCreator : public Creator<::mmdeploy::ResizeImpl> {
 public:
  const char* GetName() const override { return "cuda"; }
  int GetVersion() const override { return 1; }
  ReturnType Create(const Value& args) override { return make_unique<ResizeImpl>(args); }
};

}  // namespace cuda
}  // namespace mmdeploy

using ::mmdeploy::ResizeImpl;
using ::mmdeploy::cuda::ResizeImplCreator;
REGISTER_MODULE(ResizeImpl, ResizeImplCreator);