diff --git a/csrc/preprocess/cpu/CMakeLists.txt b/csrc/preprocess/cpu/CMakeLists.txt index d2a75b10e..219f4da41 100644 --- a/csrc/preprocess/cpu/CMakeLists.txt +++ b/csrc/preprocess/cpu/CMakeLists.txt @@ -9,6 +9,7 @@ set(SRCS collect_impl.cpp crop_impl.cpp image2tensor_impl.cpp + default_format_bundle_impl.cpp load_impl.cpp normalize_impl.cpp pad_impl.cpp diff --git a/csrc/preprocess/cpu/default_format_bundle_impl.cpp b/csrc/preprocess/cpu/default_format_bundle_impl.cpp new file mode 100644 index 000000000..efee3cc47 --- /dev/null +++ b/csrc/preprocess/cpu/default_format_bundle_impl.cpp @@ -0,0 +1,58 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "core/utils/device_utils.h" +#include "opencv_utils.h" +#include "preprocess/transform/default_format_bundle.h" + +namespace mmdeploy { +namespace cpu { + +/* CPU backend of DefaultFormatBundleImpl; both hooks go through the Tensor2CVMat / CVMat2Tensor helpers. */ class DefaultFormatBundleImpl : public ::mmdeploy::DefaultFormatBundleImpl { + public: + explicit DefaultFormatBundleImpl(const Value& args) : ::mmdeploy::DefaultFormatBundleImpl(args) {} + + protected: + /* Makes tensor available on device_ / stream_. If img_to_float is set and the element type is DataType::kINT8, the data is converted with convertTo(..., CV_32FC(channels)) and returned as a new tensor; otherwise the device-available tensor is returned as is. NOTE(review): kINT8 presumably denotes 8-bit image data here -- confirm. */ Result ToFloat32(const Tensor& tensor, const bool& img_to_float) override { + OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_)); + auto data_type = src_tensor.desc().data_type; + + if (img_to_float && data_type == DataType::kINT8) { + auto cvmat = Tensor2CVMat(src_tensor); + cvmat.convertTo(cvmat, CV_32FC(cvmat.channels())); + auto dst_tensor = CVMat2Tensor(cvmat); + return dst_tensor; + } + return src_tensor; + } + + /* Reads shape[1], shape[2], shape[3] as height, width, channels (so a 4-D shape is expected; NOTE(review): the rank is not checked here), runs Transpose() on the mat view and reshapes the result to {1, channels, height, width}. */ Result HWC2CHW(const Tensor& tensor) override { + OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_)); + auto shape = src_tensor.shape(); + int height = shape[1]; + int width = shape[2]; + int channels = shape[3]; + + auto dst_mat = Transpose(Tensor2CVMat(src_tensor)); + + auto dst_tensor = CVMat2Tensor(dst_mat); + dst_tensor.Reshape({1, channels, height, width}); + + return dst_tensor; + } +}; + +/* Creator reporting the name "cpu" and version 1; Create() builds the implementation from args with std::make_unique. */ class DefaultFormatBundleImplCreator : public 
Creator<::mmdeploy::DefaultFormatBundleImpl> { + public: + const char* GetName() const override { return "cpu"; } + int GetVersion() const override { return 1; } + ReturnType Create(const Value& args) override { + return std::make_unique(args); + } +}; + +} // namespace cpu +} // namespace mmdeploy + +using mmdeploy::DefaultFormatBundleImpl; +using mmdeploy::cpu::DefaultFormatBundleImplCreator; +/* Registers DefaultFormatBundleImplCreator for the DefaultFormatBundleImpl module type (macro defined elsewhere). */ REGISTER_MODULE(DefaultFormatBundleImpl, DefaultFormatBundleImplCreator); diff --git a/csrc/preprocess/cuda/CMakeLists.txt b/csrc/preprocess/cuda/CMakeLists.txt index 76caeb214..be1c7b33d 100644 --- a/csrc/preprocess/cuda/CMakeLists.txt +++ b/csrc/preprocess/cuda/CMakeLists.txt @@ -14,6 +14,7 @@ include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake) set(SRCS crop_impl.cpp image2tensor_impl.cpp + default_format_bundle_impl.cpp load_impl.cpp normalize_impl.cpp pad_impl.cpp diff --git a/csrc/preprocess/cuda/default_format_bundle_impl.cpp b/csrc/preprocess/cuda/default_format_bundle_impl.cpp new file mode 100644 index 000000000..2091d4ac8 --- /dev/null +++ b/csrc/preprocess/cuda/default_format_bundle_impl.cpp @@ -0,0 +1,87 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#include + +#include "core/utils/device_utils.h" +#include "preprocess/transform/default_format_bundle.h" + +namespace mmdeploy { +namespace cuda { + +template +void CastToFloat(const uint8_t* src, int height, int width, float* dst, cudaStream_t stream); + +template +void Transpose(const T* src, int height, int width, int channels, T* dst, cudaStream_t stream); + +class DefaultFormatBundleImpl final : public ::mmdeploy::DefaultFormatBundleImpl { + public: + explicit DefaultFormatBundleImpl(const Value& args) : ::mmdeploy::DefaultFormatBundleImpl(args) {} + + protected: + Result ToFloat32(const Tensor& tensor, const bool& img_to_float) override { + OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_)); + auto data_type = src_tensor.data_type(); + auto h = tensor.shape(1); + auto w = tensor.shape(2); + auto c = tensor.shape(3); + auto stream = ::mmdeploy::GetNative(stream_); + + if (img_to_float && data_type == DataType::kINT8) { + TensorDesc desc{device_, DataType::kFLOAT, tensor.shape(), ""}; + Tensor dst_tensor{desc}; + if (c == 3) { + CastToFloat<3>(src_tensor.data(), h, w, dst_tensor.data(), stream); + } else if (c == 1) { + CastToFloat<1>(src_tensor.data(), h, w, dst_tensor.data(), stream); + } else { + MMDEPLOY_ERROR("channel num: unsupported channel num {}", c); + return Status(eNotSupported); + } + return dst_tensor; + } + return src_tensor; + } + + Result HWC2CHW(const Tensor& tensor) override { + OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_)); + auto h = tensor.shape(1); + auto w = tensor.shape(2); + auto c = tensor.shape(3); + auto hw = h * w; + + Tensor dst_tensor(src_tensor.desc()); + dst_tensor.Reshape({1, c, h, w}); + + auto stream = ::mmdeploy::GetNative(stream_); + + if (DataType::kINT8 == tensor.data_type()) { + auto input = src_tensor.data(); + auto output = dst_tensor.data(); + Transpose(input, (int)h, (int)w, (int)c, output, stream); + } else if (DataType::kFLOAT == tensor.data_type()) { 
+ auto input = src_tensor.data(); + auto output = dst_tensor.data(); + Transpose(input, (int)h, (int)w, (int)c, output, stream); + } else { + assert(0); + } + return dst_tensor; + } +}; + +class DefaultFormatBundleImplCreator : public Creator<::mmdeploy::DefaultFormatBundleImpl> { + public: + const char* GetName() const override { return "cuda"; } + int GetVersion() const override { return 1; } + ReturnType Create(const Value& cfg) override { + return std::make_unique(cfg); + } +}; + +} // namespace cuda +} // namespace mmdeploy + +using ::mmdeploy::DefaultFormatBundleImpl; +using ::mmdeploy::cuda::DefaultFormatBundleImplCreator; +REGISTER_MODULE(DefaultFormatBundleImpl, DefaultFormatBundleImplCreator); diff --git a/csrc/preprocess/transform/CMakeLists.txt b/csrc/preprocess/transform/CMakeLists.txt index 8e13a67ae..22100838e 100644 --- a/csrc/preprocess/transform/CMakeLists.txt +++ b/csrc/preprocess/transform/CMakeLists.txt @@ -9,6 +9,7 @@ set(SRCS compose.cpp crop.cpp image2tensor.cpp + default_format_bundle.cpp load.cpp normalize.cpp pad.cpp diff --git a/csrc/preprocess/transform/collect.h b/csrc/preprocess/transform/collect.h index 327c5191e..d374de0e8 100644 --- a/csrc/preprocess/transform/collect.h +++ b/csrc/preprocess/transform/collect.h @@ -9,7 +9,7 @@ namespace mmdeploy { class MMDEPLOY_API CollectImpl : public Module { public: explicit CollectImpl(const Value& args); - ~CollectImpl() = default; + ~CollectImpl() override = default; Result Process(const Value& input) override; @@ -27,7 +27,7 @@ class MMDEPLOY_API CollectImpl : public Module { class MMDEPLOY_API Collect : public Transform { public: explicit Collect(const Value& args, int version = 0); - ~Collect() = default; + ~Collect() override = default; Result Process(const Value& input) override; diff --git a/csrc/preprocess/transform/crop.h b/csrc/preprocess/transform/crop.h index 76c567271..1f96b19b2 100644 --- a/csrc/preprocess/transform/crop.h +++ b/csrc/preprocess/transform/crop.h @@ -13,7 +13,7 @@ 
namespace mmdeploy { class MMDEPLOY_API CenterCropImpl : public TransformImpl { public: explicit CenterCropImpl(const Value& args); - ~CenterCropImpl() = default; + ~CenterCropImpl() override = default; Result Process(const Value& input) override; @@ -34,7 +34,7 @@ class MMDEPLOY_API CenterCropImpl : public TransformImpl { class MMDEPLOY_API CenterCrop : public Transform { public: explicit CenterCrop(const Value& args, int version = 0); - ~CenterCrop() = default; + ~CenterCrop() override = default; Result Process(const Value& input) override { return impl_->Process(input); } diff --git a/csrc/preprocess/transform/default_format_bundle.cpp b/csrc/preprocess/transform/default_format_bundle.cpp new file mode 100644 index 000000000..7dbcbfa73 --- /dev/null +++ b/csrc/preprocess/transform/default_format_bundle.cpp @@ -0,0 +1,75 @@ +// Copyright (c) OpenMMLab. All rights reserved. + +#include "default_format_bundle.h" + +#include + +#include "archive/json_archive.h" +#include "core/tensor.h" + +namespace mmdeploy { + +/* Reads the optional boolean "img_to_float" entry from args; when it is absent or not a boolean the value already in arg_ is kept. */ DefaultFormatBundleImpl::DefaultFormatBundleImpl(const Value& args) : TransformImpl(args) { + if (args.contains("img_to_float") && args["img_to_float"].is_boolean()) { + arg_.img_to_float = args["img_to_float"].get(); + } +} + +/* Copies input to output. If an "img" entry exists: runs ToFloat32 on it, fills "pad_shape" (from the tensor shape), "scale_factor" ([1.0]) and "img_norm_cfg" (mean 0.0 and std 1.0 per channel, to_rgb false) only when those keys are absent, then replaces "img" with HWC2CHW(tensor). Errors from the hooks propagate through OUTCOME_TRY. NOTE(review): shape()[3] is read without a rank check. */ Result DefaultFormatBundleImpl::Process(const Value& input) { + MMDEPLOY_DEBUG("DefaultFormatBundle input: {}", to_json(input).dump(2)); + Value output = input; + if (input.contains("img")) { + Tensor in_tensor = input["img"].get(); + OUTCOME_TRY(output["img"], ToFloat32(in_tensor, arg_.img_to_float)); + + Tensor tensor = output["img"].get(); + // set default meta keys + if (!output.contains("pad_shape")) { + for (auto v : tensor.shape()) { + output["pad_shape"].push_back(v); + } + } + if (!output.contains("scale_factor")) { + output["scale_factor"].push_back(1.0); + } + if (!output.contains("img_norm_cfg")) { + int channel = tensor.shape()[3]; + for (int i = 0; i < channel; i++) { + output["img_norm_cfg"]["mean"].push_back(0.0); + 
output["img_norm_cfg"]["std"].push_back(1.0); + } + output["img_norm_cfg"]["to_rgb"] = false; + } + + // transpose + OUTCOME_TRY(output["img"], HWC2CHW(tensor)); + } + + MMDEPLOY_DEBUG("DefaultFormatBundle output: {}", to_json(output).dump(2)); + return output; +} + +/* Looks up the implementation creator for specified_platform_ and version; logs and throws std::domain_error when none is found, otherwise stores the created implementation in impl_. */ DefaultFormatBundle::DefaultFormatBundle(const Value& args, int version) : Transform(args) { + auto impl_creator = + Registry::Get().GetCreator(specified_platform_, version); + if (nullptr == impl_creator) { + MMDEPLOY_ERROR("'DefaultFormatBundle' is not supported on '{}' platform", specified_platform_); + throw std::domain_error("'DefaultFormatBundle' is not supported on specified platform"); + } + impl_ = impl_creator->Create(args); +} + +/* Creator exposing the transform under the name "DefaultFormatBundle"; version_ is 1 and is passed on to the object built in Create(). */ class DefaultFormatBundleCreator : public Creator { + public: + const char* GetName() const override { return "DefaultFormatBundle"; } + int GetVersion() const override { return version_; } + ReturnType Create(const Value& args) override { + return std::make_unique(args, version_); + } + + private: + int version_{1}; +}; +REGISTER_MODULE(Transform, DefaultFormatBundleCreator); +MMDEPLOY_DEFINE_REGISTRY(DefaultFormatBundleImpl); +} // namespace mmdeploy diff --git a/csrc/preprocess/transform/default_format_bundle.h b/csrc/preprocess/transform/default_format_bundle.h new file mode 100644 index 000000000..708b737f1 --- /dev/null +++ b/csrc/preprocess/transform/default_format_bundle.h @@ -0,0 +1,49 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+ +#ifndef MMDEPLOY_DEFAULT_FORMAT_BUNDLE_H +#define MMDEPLOY_DEFAULT_FORMAT_BUNDLE_H + +#include "core/tensor.h" +#include "transform.h" + +namespace mmdeploy { +/** + * It simplifies the pipeline of formatting common fields. Backends supply the two pure virtual hooks ToFloat32 and HWC2CHW; Process() is declared here and defined in the matching .cpp file. + */ +class MMDEPLOY_API DefaultFormatBundleImpl : public TransformImpl { + public: + DefaultFormatBundleImpl(const Value& args); + ~DefaultFormatBundleImpl() override = default; + + Result Process(const Value& input) override; + + protected: + /* Backend hook: returns the tensor to use as the image, converted to float when the backend decides the img_to_float flag applies. */ virtual Result ToFloat32(const Tensor& tensor, const bool& img_to_float) = 0; + /* Backend hook: returns the tensor rearranged from HWC to CHW layout (as the name states; the layout itself is defined by the backend implementations). */ virtual Result HWC2CHW(const Tensor& tensor) = 0; + + protected: + /* Options of the transform; img_to_float defaults to true. */ struct default_format_bundle_arg_t { + bool img_to_float = true; + }; + using ArgType = struct default_format_bundle_arg_t; + + protected: + ArgType arg_; +}; + +/* Transform front end: owns an implementation object in impl_ and forwards Process() to it. */ class MMDEPLOY_API DefaultFormatBundle : public Transform { + public: + explicit DefaultFormatBundle(const Value& args, int version = 0); + ~DefaultFormatBundle() override = default; + + Result Process(const Value& input) override { return impl_->Process(input); } + + private: + std::unique_ptr impl_; +}; + +MMDEPLOY_DECLARE_REGISTRY(DefaultFormatBundleImpl); + +} // namespace mmdeploy + +#endif // MMDEPLOY_DEFAULT_FORMAT_BUNDLE_H diff --git a/csrc/preprocess/transform/image2tensor.h b/csrc/preprocess/transform/image2tensor.h index 49eefd9f4..67acb2c67 100644 --- a/csrc/preprocess/transform/image2tensor.h +++ b/csrc/preprocess/transform/image2tensor.h @@ -17,7 +17,7 @@ namespace mmdeploy { class MMDEPLOY_API ImageToTensorImpl : public TransformImpl { public: ImageToTensorImpl(const Value& args); - ~ImageToTensorImpl() = default; + ~ImageToTensorImpl() override = default; Result Process(const Value& input) override; @@ -37,7 +37,7 @@ class MMDEPLOY_API ImageToTensorImpl : public TransformImpl { class MMDEPLOY_API ImageToTensor : public Transform { public: explicit ImageToTensor(const Value& args, int version = 0); - ~ImageToTensor() = default; + ~ImageToTensor() override = default; Result 
Process(const Value& input) override { return impl_->Process(input); } diff --git a/csrc/preprocess/transform/load.h b/csrc/preprocess/transform/load.h index a05d4c136..66f53d8f3 100644 --- a/csrc/preprocess/transform/load.h +++ b/csrc/preprocess/transform/load.h @@ -11,7 +11,7 @@ namespace mmdeploy { class MMDEPLOY_API PrepareImageImpl : public TransformImpl { public: explicit PrepareImageImpl(const Value& args); - ~PrepareImageImpl() = default; + ~PrepareImageImpl() override = default; Result Process(const Value& input) override; @@ -32,7 +32,7 @@ class MMDEPLOY_API PrepareImageImpl : public TransformImpl { class MMDEPLOY_API PrepareImage : public Transform { public: explicit PrepareImage(const Value& args, int version = 0); - ~PrepareImage() = default; + ~PrepareImage() override = default; Result Process(const Value& input) override { return impl_->Process(input); } diff --git a/csrc/preprocess/transform/normalize.h b/csrc/preprocess/transform/normalize.h index fef8fd17c..f06adddaf 100644 --- a/csrc/preprocess/transform/normalize.h +++ b/csrc/preprocess/transform/normalize.h @@ -11,7 +11,7 @@ namespace mmdeploy { class MMDEPLOY_API NormalizeImpl : public TransformImpl { public: explicit NormalizeImpl(const Value& args); - ~NormalizeImpl() = default; + ~NormalizeImpl() override = default; Result Process(const Value& input) override; @@ -31,7 +31,7 @@ class MMDEPLOY_API NormalizeImpl : public TransformImpl { class MMDEPLOY_API Normalize : public Transform { public: explicit Normalize(const Value& args, int version = 0); - ~Normalize() = default; + ~Normalize() override = default; Result Process(const Value& input) override { return impl_->Process(input); } diff --git a/mmdeploy/utils/export_info.py b/mmdeploy/utils/export_info.py index 8c5466f89..bd9d3e36e 100644 --- a/mmdeploy/utils/export_info.py +++ b/mmdeploy/utils/export_info.py @@ -191,8 +191,6 @@ def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config): and 'RescaleToZeroOne' not in item['type'] 
] for i, transform in enumerate(transforms): - if transform['type'] == 'DefaultFormatBundle': - transforms[i] = dict(type='ImageToTensor', keys=['img']) if 'keys' in transform and transform['keys'] == ['lq']: transform['keys'] = ['img'] if 'key' in transform and transform['key'] == 'lq': diff --git a/tests/test_csrc/preprocess/test_default_format_bundle.cpp b/tests/test_csrc/preprocess/test_default_format_bundle.cpp new file mode 100644 index 000000000..84c669b69 --- /dev/null +++ b/tests/test_csrc/preprocess/test_default_format_bundle.cpp @@ -0,0 +1,68 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#include "catch.hpp" +#include "core/tensor.h" +#include "core/utils/device_utils.h" +#include "opencv_utils.h" +#include "preprocess/transform/transform.h" +#include "test_resource.h" +#include "test_utils.h" + +using namespace mmdeploy; +using namespace mmdeploy::test; +using namespace std; + +/* For every device name reported by the test resources: builds the transform from cfg, runs it on mat wrapped as a tensor, checks the output device and the {1, c, h, w} shape, then makes the result available on the host and compares it channel by channel against mat. */ void TestDefaultFormatBundle(const Value& cfg, const cv::Mat& mat) { + auto gResource = MMDeployTestResources::Get(); + for (auto const& device_name : gResource.device_names()) { + Device device{device_name.c_str()}; + Stream stream{device}; + auto transform = CreateTransform(cfg, device, stream); + REQUIRE(transform != nullptr); + + vector channel_mats(mat.channels()); + for (auto i = 0; i < mat.channels(); ++i) { + cv::extractChannel(mat, channel_mats[i], i); + } + + auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}}); + REQUIRE(!res.has_error()); + auto res_tensor = res.value()["img"].get(); + REQUIRE(res_tensor.device() == device); + auto shape = res_tensor.desc().shape; + REQUIRE(shape == std::vector{1, mat.channels(), mat.rows, mat.cols}); + + const Device kHost{"cpu"}; + auto host_tensor = MakeAvailableOnDevice(res_tensor, kHost, stream); + REQUIRE(stream.Wait()); + + // mat's shape is {h, w, c}, while res_tensor's shape is {1, c, h, w} + // compare each channel between `res_tensor` and `mat` + auto step = shape[2] * shape[3] * mat.elemSize1(); + auto 
data = host_tensor.value().data(); + for (auto i = 0; i < mat.channels(); ++i) { + /* NOTE(review): _mat reuses mat.depth() and step uses mat.elemSize1(); cfg does not set img_to_float, so confirm that 8-bit inputs are not returned as float here. */ cv::Mat _mat{mat.rows, mat.cols, CV_MAKETYPE(mat.depth(), 1), data}; + REQUIRE(::mmdeploy::cpu::Compare(channel_mats[i], _mat)); + data += step; + } + } +} + +/* Runs the check on the first image located under the "transform" resources, loaded as color and as grayscale, each also converted to 32-bit float. NOTE(review): the tag is [img2tensor] -- confirm it is intended for this test. */ TEST_CASE("transform DefaultFormatBundle", "[img2tensor]") { + auto gResource = MMDeployTestResources::Get(); + auto img_list = gResource.LocateImageResources("transform"); + REQUIRE(!img_list.empty()); + + auto img_path = img_list.front(); + cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR); + cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE); + cv::Mat bgr_float_mat; + cv::Mat gray_float_mat; + bgr_mat.convertTo(bgr_float_mat, CV_32FC3); + gray_mat.convertTo(gray_float_mat, CV_32FC1); + + Value cfg{{"type", "DefaultFormatBundle"}, {"keys", {"img"}}}; + vector mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat}; + for (auto& mat : mats) { + TestDefaultFormatBundle(cfg, mat); + } +}