// Copyright (c) OpenMMLab. All rights reserved. #include #include "core/utils/device_utils.h" #include "preprocess/transform/default_format_bundle.h" namespace mmdeploy { namespace cuda { template void CastToFloat(const uint8_t* src, int height, int width, float* dst, cudaStream_t stream); template void Transpose(const T* src, int height, int width, int channels, T* dst, cudaStream_t stream); class DefaultFormatBundleImpl final : public ::mmdeploy::DefaultFormatBundleImpl { public: explicit DefaultFormatBundleImpl(const Value& args) : ::mmdeploy::DefaultFormatBundleImpl(args) {} protected: Result ToFloat32(const Tensor& tensor, const bool& img_to_float) override { OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_)); SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor); auto data_type = src_tensor.data_type(); auto h = tensor.shape(1); auto w = tensor.shape(2); auto c = tensor.shape(3); auto stream = ::mmdeploy::GetNative(stream_); if (img_to_float && data_type == DataType::kINT8) { TensorDesc desc{device_, DataType::kFLOAT, tensor.shape(), ""}; Tensor dst_tensor{desc}; if (c == 3) { CastToFloat<3>(src_tensor.data(), h, w, dst_tensor.data(), stream); } else if (c == 1) { CastToFloat<1>(src_tensor.data(), h, w, dst_tensor.data(), stream); } else { MMDEPLOY_ERROR("channel num: unsupported channel num {}", c); return Status(eNotSupported); } return dst_tensor; } return src_tensor; } Result HWC2CHW(const Tensor& tensor) override { OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_)); SyncOnScopeExit sync(stream_, src_tensor.buffer() != tensor.buffer(), src_tensor); auto h = tensor.shape(1); auto w = tensor.shape(2); auto c = tensor.shape(3); auto hw = h * w; Tensor dst_tensor(src_tensor.desc()); dst_tensor.Reshape({1, c, h, w}); auto stream = ::mmdeploy::GetNative(stream_); if (DataType::kINT8 == tensor.data_type()) { auto input = src_tensor.data(); auto output = dst_tensor.data(); Transpose(input, (int)h, (int)w, (int)c, output, stream); } else if (DataType::kFLOAT == tensor.data_type()) { auto input = src_tensor.data(); auto output = dst_tensor.data(); Transpose(input, (int)h, (int)w, (int)c, output, stream); } else { assert(0); } return dst_tensor; } }; class DefaultFormatBundleImplCreator : public Creator<::mmdeploy::DefaultFormatBundleImpl> { public: const char* GetName() const override { return "cuda"; } int GetVersion() const override { return 1; } ReturnType Create(const Value& cfg) override { return std::make_unique(cfg); } }; } // namespace cuda } // namespace mmdeploy using ::mmdeploy::DefaultFormatBundleImpl; using ::mmdeploy::cuda::DefaultFormatBundleImplCreator; REGISTER_MODULE(DefaultFormatBundleImpl, DefaultFormatBundleImplCreator);