mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
Add DefaultFormatBundle (#208)
* keep DefaultFormatBundle * add DefaultFormatBundle * add condition * resolve comments * remove useless * add override
This commit is contained in:
parent
776659a6ce
commit
ea54f3b2fd
@ -9,6 +9,7 @@ set(SRCS
|
||||
collect_impl.cpp
|
||||
crop_impl.cpp
|
||||
image2tensor_impl.cpp
|
||||
default_format_bundle_impl.cpp
|
||||
load_impl.cpp
|
||||
normalize_impl.cpp
|
||||
pad_impl.cpp
|
||||
|
58
csrc/preprocess/cpu/default_format_bundle_impl.cpp
Normal file
58
csrc/preprocess/cpu/default_format_bundle_impl.cpp
Normal file
@ -0,0 +1,58 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "core/utils/device_utils.h"
|
||||
#include "opencv_utils.h"
|
||||
#include "preprocess/transform/default_format_bundle.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace cpu {
|
||||
|
||||
class DefaultFormatBundleImpl : public ::mmdeploy::DefaultFormatBundleImpl {
|
||||
public:
|
||||
explicit DefaultFormatBundleImpl(const Value& args) : ::mmdeploy::DefaultFormatBundleImpl(args) {}
|
||||
|
||||
protected:
|
||||
Result<Tensor> ToFloat32(const Tensor& tensor, const bool& img_to_float) override {
|
||||
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
|
||||
auto data_type = src_tensor.desc().data_type;
|
||||
|
||||
if (img_to_float && data_type == DataType::kINT8) {
|
||||
auto cvmat = Tensor2CVMat(src_tensor);
|
||||
cvmat.convertTo(cvmat, CV_32FC(cvmat.channels()));
|
||||
auto dst_tensor = CVMat2Tensor(cvmat);
|
||||
return dst_tensor;
|
||||
}
|
||||
return src_tensor;
|
||||
}
|
||||
|
||||
Result<Tensor> HWC2CHW(const Tensor& tensor) override {
|
||||
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
|
||||
auto shape = src_tensor.shape();
|
||||
int height = shape[1];
|
||||
int width = shape[2];
|
||||
int channels = shape[3];
|
||||
|
||||
auto dst_mat = Transpose(Tensor2CVMat(src_tensor));
|
||||
|
||||
auto dst_tensor = CVMat2Tensor(dst_mat);
|
||||
dst_tensor.Reshape({1, channels, height, width});
|
||||
|
||||
return dst_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
class DefaultFormatBundleImplCreator : public Creator<::mmdeploy::DefaultFormatBundleImpl> {
|
||||
public:
|
||||
const char* GetName() const override { return "cpu"; }
|
||||
int GetVersion() const override { return 1; }
|
||||
ReturnType Create(const Value& args) override {
|
||||
return std::make_unique<DefaultFormatBundleImpl>(args);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace mmdeploy
|
||||
|
||||
using mmdeploy::DefaultFormatBundleImpl;
|
||||
using mmdeploy::cpu::DefaultFormatBundleImplCreator;
|
||||
REGISTER_MODULE(DefaultFormatBundleImpl, DefaultFormatBundleImplCreator);
|
@ -14,6 +14,7 @@ include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake)
|
||||
set(SRCS
|
||||
crop_impl.cpp
|
||||
image2tensor_impl.cpp
|
||||
default_format_bundle_impl.cpp
|
||||
load_impl.cpp
|
||||
normalize_impl.cpp
|
||||
pad_impl.cpp
|
||||
|
87
csrc/preprocess/cuda/default_format_bundle_impl.cpp
Normal file
87
csrc/preprocess/cuda/default_format_bundle_impl.cpp
Normal file
@ -0,0 +1,87 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "core/utils/device_utils.h"
|
||||
#include "preprocess/transform/default_format_bundle.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace cuda {
|
||||
|
||||
template <int channels>
|
||||
void CastToFloat(const uint8_t* src, int height, int width, float* dst, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void Transpose(const T* src, int height, int width, int channels, T* dst, cudaStream_t stream);
|
||||
|
||||
class DefaultFormatBundleImpl final : public ::mmdeploy::DefaultFormatBundleImpl {
|
||||
public:
|
||||
explicit DefaultFormatBundleImpl(const Value& args) : ::mmdeploy::DefaultFormatBundleImpl(args) {}
|
||||
|
||||
protected:
|
||||
Result<Tensor> ToFloat32(const Tensor& tensor, const bool& img_to_float) override {
|
||||
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
|
||||
auto data_type = src_tensor.data_type();
|
||||
auto h = tensor.shape(1);
|
||||
auto w = tensor.shape(2);
|
||||
auto c = tensor.shape(3);
|
||||
auto stream = ::mmdeploy::GetNative<cudaStream_t>(stream_);
|
||||
|
||||
if (img_to_float && data_type == DataType::kINT8) {
|
||||
TensorDesc desc{device_, DataType::kFLOAT, tensor.shape(), ""};
|
||||
Tensor dst_tensor{desc};
|
||||
if (c == 3) {
|
||||
CastToFloat<3>(src_tensor.data<uint8_t>(), h, w, dst_tensor.data<float>(), stream);
|
||||
} else if (c == 1) {
|
||||
CastToFloat<1>(src_tensor.data<uint8_t>(), h, w, dst_tensor.data<float>(), stream);
|
||||
} else {
|
||||
MMDEPLOY_ERROR("channel num: unsupported channel num {}", c);
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
return dst_tensor;
|
||||
}
|
||||
return src_tensor;
|
||||
}
|
||||
|
||||
Result<Tensor> HWC2CHW(const Tensor& tensor) override {
|
||||
OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
|
||||
auto h = tensor.shape(1);
|
||||
auto w = tensor.shape(2);
|
||||
auto c = tensor.shape(3);
|
||||
auto hw = h * w;
|
||||
|
||||
Tensor dst_tensor(src_tensor.desc());
|
||||
dst_tensor.Reshape({1, c, h, w});
|
||||
|
||||
auto stream = ::mmdeploy::GetNative<cudaStream_t>(stream_);
|
||||
|
||||
if (DataType::kINT8 == tensor.data_type()) {
|
||||
auto input = src_tensor.data<uint8_t>();
|
||||
auto output = dst_tensor.data<uint8_t>();
|
||||
Transpose(input, (int)h, (int)w, (int)c, output, stream);
|
||||
} else if (DataType::kFLOAT == tensor.data_type()) {
|
||||
auto input = src_tensor.data<float>();
|
||||
auto output = dst_tensor.data<float>();
|
||||
Transpose(input, (int)h, (int)w, (int)c, output, stream);
|
||||
} else {
|
||||
assert(0);
|
||||
}
|
||||
return dst_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
class DefaultFormatBundleImplCreator : public Creator<::mmdeploy::DefaultFormatBundleImpl> {
|
||||
public:
|
||||
const char* GetName() const override { return "cuda"; }
|
||||
int GetVersion() const override { return 1; }
|
||||
ReturnType Create(const Value& cfg) override {
|
||||
return std::make_unique<DefaultFormatBundleImpl>(cfg);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace mmdeploy
|
||||
|
||||
using ::mmdeploy::DefaultFormatBundleImpl;
|
||||
using ::mmdeploy::cuda::DefaultFormatBundleImplCreator;
|
||||
REGISTER_MODULE(DefaultFormatBundleImpl, DefaultFormatBundleImplCreator);
|
@ -9,6 +9,7 @@ set(SRCS
|
||||
compose.cpp
|
||||
crop.cpp
|
||||
image2tensor.cpp
|
||||
default_format_bundle.cpp
|
||||
load.cpp
|
||||
normalize.cpp
|
||||
pad.cpp
|
||||
|
@ -9,7 +9,7 @@ namespace mmdeploy {
|
||||
class MMDEPLOY_API CollectImpl : public Module {
|
||||
public:
|
||||
explicit CollectImpl(const Value& args);
|
||||
~CollectImpl() = default;
|
||||
~CollectImpl() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override;
|
||||
|
||||
@ -27,7 +27,7 @@ class MMDEPLOY_API CollectImpl : public Module {
|
||||
class MMDEPLOY_API Collect : public Transform {
|
||||
public:
|
||||
explicit Collect(const Value& args, int version = 0);
|
||||
~Collect() = default;
|
||||
~Collect() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override;
|
||||
|
||||
|
@ -13,7 +13,7 @@ namespace mmdeploy {
|
||||
class MMDEPLOY_API CenterCropImpl : public TransformImpl {
|
||||
public:
|
||||
explicit CenterCropImpl(const Value& args);
|
||||
~CenterCropImpl() = default;
|
||||
~CenterCropImpl() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override;
|
||||
|
||||
@ -34,7 +34,7 @@ class MMDEPLOY_API CenterCropImpl : public TransformImpl {
|
||||
class MMDEPLOY_API CenterCrop : public Transform {
|
||||
public:
|
||||
explicit CenterCrop(const Value& args, int version = 0);
|
||||
~CenterCrop() = default;
|
||||
~CenterCrop() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override { return impl_->Process(input); }
|
||||
|
||||
|
75
csrc/preprocess/transform/default_format_bundle.cpp
Normal file
75
csrc/preprocess/transform/default_format_bundle.cpp
Normal file
@ -0,0 +1,75 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "default_format_bundle.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "archive/json_archive.h"
|
||||
#include "core/tensor.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
|
||||
DefaultFormatBundleImpl::DefaultFormatBundleImpl(const Value& args) : TransformImpl(args) {
|
||||
if (args.contains("img_to_float") && args["img_to_float"].is_boolean()) {
|
||||
arg_.img_to_float = args["img_to_float"].get<bool>();
|
||||
}
|
||||
}
|
||||
|
||||
Result<Value> DefaultFormatBundleImpl::Process(const Value& input) {
|
||||
MMDEPLOY_DEBUG("DefaultFormatBundle input: {}", to_json(input).dump(2));
|
||||
Value output = input;
|
||||
if (input.contains("img")) {
|
||||
Tensor in_tensor = input["img"].get<Tensor>();
|
||||
OUTCOME_TRY(output["img"], ToFloat32(in_tensor, arg_.img_to_float));
|
||||
|
||||
Tensor tensor = output["img"].get<Tensor>();
|
||||
// set default meta keys
|
||||
if (!output.contains("pad_shape")) {
|
||||
for (auto v : tensor.shape()) {
|
||||
output["pad_shape"].push_back(v);
|
||||
}
|
||||
}
|
||||
if (!output.contains("scale_factor")) {
|
||||
output["scale_factor"].push_back(1.0);
|
||||
}
|
||||
if (!output.contains("img_norm_cfg")) {
|
||||
int channel = tensor.shape()[3];
|
||||
for (int i = 0; i < channel; i++) {
|
||||
output["img_norm_cfg"]["mean"].push_back(0.0);
|
||||
output["img_norm_cfg"]["std"].push_back(1.0);
|
||||
}
|
||||
output["img_norm_cfg"]["to_rgb"] = false;
|
||||
}
|
||||
|
||||
// transpose
|
||||
OUTCOME_TRY(output["img"], HWC2CHW(tensor));
|
||||
}
|
||||
|
||||
MMDEPLOY_DEBUG("DefaultFormatBundle output: {}", to_json(output).dump(2));
|
||||
return output;
|
||||
}
|
||||
|
||||
DefaultFormatBundle::DefaultFormatBundle(const Value& args, int version) : Transform(args) {
|
||||
auto impl_creator =
|
||||
Registry<DefaultFormatBundleImpl>::Get().GetCreator(specified_platform_, version);
|
||||
if (nullptr == impl_creator) {
|
||||
MMDEPLOY_ERROR("'DefaultFormatBundle' is not supported on '{}' platform", specified_platform_);
|
||||
throw std::domain_error("'DefaultFormatBundle' is not supported on specified platform");
|
||||
}
|
||||
impl_ = impl_creator->Create(args);
|
||||
}
|
||||
|
||||
class DefaultFormatBundleCreator : public Creator<Transform> {
|
||||
public:
|
||||
const char* GetName() const override { return "DefaultFormatBundle"; }
|
||||
int GetVersion() const override { return version_; }
|
||||
ReturnType Create(const Value& args) override {
|
||||
return std::make_unique<DefaultFormatBundle>(args, version_);
|
||||
}
|
||||
|
||||
private:
|
||||
int version_{1};
|
||||
};
|
||||
REGISTER_MODULE(Transform, DefaultFormatBundleCreator);
|
||||
MMDEPLOY_DEFINE_REGISTRY(DefaultFormatBundleImpl);
|
||||
} // namespace mmdeploy
|
49
csrc/preprocess/transform/default_format_bundle.h
Normal file
49
csrc/preprocess/transform/default_format_bundle.h
Normal file
@ -0,0 +1,49 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#ifndef MMDEPLOY_DEFAULT_FORMAT_BUNDLE_H
|
||||
#define MMDEPLOY_DEFAULT_FORMAT_BUNDLE_H
|
||||
|
||||
#include "core/tensor.h"
|
||||
#include "transform.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
/**
|
||||
* It simplifies the pipeline of formatting common fields
|
||||
*/
|
||||
class MMDEPLOY_API DefaultFormatBundleImpl : public TransformImpl {
|
||||
public:
|
||||
DefaultFormatBundleImpl(const Value& args);
|
||||
~DefaultFormatBundleImpl() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override;
|
||||
|
||||
protected:
|
||||
virtual Result<Tensor> ToFloat32(const Tensor& tensor, const bool& img_to_float) = 0;
|
||||
virtual Result<Tensor> HWC2CHW(const Tensor& tensor) = 0;
|
||||
|
||||
protected:
|
||||
struct default_format_bundle_arg_t {
|
||||
bool img_to_float = true;
|
||||
};
|
||||
using ArgType = struct default_format_bundle_arg_t;
|
||||
|
||||
protected:
|
||||
ArgType arg_;
|
||||
};
|
||||
|
||||
class MMDEPLOY_API DefaultFormatBundle : public Transform {
|
||||
public:
|
||||
explicit DefaultFormatBundle(const Value& args, int version = 0);
|
||||
~DefaultFormatBundle() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override { return impl_->Process(input); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<DefaultFormatBundleImpl> impl_;
|
||||
};
|
||||
|
||||
MMDEPLOY_DECLARE_REGISTRY(DefaultFormatBundleImpl);
|
||||
|
||||
} // namespace mmdeploy
|
||||
|
||||
#endif // MMDEPLOY_DEFAULT_FORMAT_BUNDLE_H
|
@ -17,7 +17,7 @@ namespace mmdeploy {
|
||||
class MMDEPLOY_API ImageToTensorImpl : public TransformImpl {
|
||||
public:
|
||||
ImageToTensorImpl(const Value& args);
|
||||
~ImageToTensorImpl() = default;
|
||||
~ImageToTensorImpl() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override;
|
||||
|
||||
@ -37,7 +37,7 @@ class MMDEPLOY_API ImageToTensorImpl : public TransformImpl {
|
||||
class MMDEPLOY_API ImageToTensor : public Transform {
|
||||
public:
|
||||
explicit ImageToTensor(const Value& args, int version = 0);
|
||||
~ImageToTensor() = default;
|
||||
~ImageToTensor() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override { return impl_->Process(input); }
|
||||
|
||||
|
@ -11,7 +11,7 @@ namespace mmdeploy {
|
||||
class MMDEPLOY_API PrepareImageImpl : public TransformImpl {
|
||||
public:
|
||||
explicit PrepareImageImpl(const Value& args);
|
||||
~PrepareImageImpl() = default;
|
||||
~PrepareImageImpl() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override;
|
||||
|
||||
@ -32,7 +32,7 @@ class MMDEPLOY_API PrepareImageImpl : public TransformImpl {
|
||||
class MMDEPLOY_API PrepareImage : public Transform {
|
||||
public:
|
||||
explicit PrepareImage(const Value& args, int version = 0);
|
||||
~PrepareImage() = default;
|
||||
~PrepareImage() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override { return impl_->Process(input); }
|
||||
|
||||
|
@ -11,7 +11,7 @@ namespace mmdeploy {
|
||||
class MMDEPLOY_API NormalizeImpl : public TransformImpl {
|
||||
public:
|
||||
explicit NormalizeImpl(const Value& args);
|
||||
~NormalizeImpl() = default;
|
||||
~NormalizeImpl() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override;
|
||||
|
||||
@ -31,7 +31,7 @@ class MMDEPLOY_API NormalizeImpl : public TransformImpl {
|
||||
class MMDEPLOY_API Normalize : public Transform {
|
||||
public:
|
||||
explicit Normalize(const Value& args, int version = 0);
|
||||
~Normalize() = default;
|
||||
~Normalize() override = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override { return impl_->Process(input); }
|
||||
|
||||
|
@ -191,8 +191,6 @@ def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config):
|
||||
and 'RescaleToZeroOne' not in item['type']
|
||||
]
|
||||
for i, transform in enumerate(transforms):
|
||||
if transform['type'] == 'DefaultFormatBundle':
|
||||
transforms[i] = dict(type='ImageToTensor', keys=['img'])
|
||||
if 'keys' in transform and transform['keys'] == ['lq']:
|
||||
transform['keys'] = ['img']
|
||||
if 'key' in transform and transform['key'] == 'lq':
|
||||
|
68
tests/test_csrc/preprocess/test_default_format_bundle.cpp
Normal file
68
tests/test_csrc/preprocess/test_default_format_bundle.cpp
Normal file
@ -0,0 +1,68 @@
|
||||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include "catch.hpp"
|
||||
#include "core/tensor.h"
|
||||
#include "core/utils/device_utils.h"
|
||||
#include "opencv_utils.h"
|
||||
#include "preprocess/transform/transform.h"
|
||||
#include "test_resource.h"
|
||||
#include "test_utils.h"
|
||||
|
||||
using namespace mmdeploy;
|
||||
using namespace mmdeploy::test;
|
||||
using namespace std;
|
||||
|
||||
void TestDefaultFormatBundle(const Value& cfg, const cv::Mat& mat) {
|
||||
auto gResource = MMDeployTestResources::Get();
|
||||
for (auto const& device_name : gResource.device_names()) {
|
||||
Device device{device_name.c_str()};
|
||||
Stream stream{device};
|
||||
auto transform = CreateTransform(cfg, device, stream);
|
||||
REQUIRE(transform != nullptr);
|
||||
|
||||
vector<cv::Mat> channel_mats(mat.channels());
|
||||
for (auto i = 0; i < mat.channels(); ++i) {
|
||||
cv::extractChannel(mat, channel_mats[i], i);
|
||||
}
|
||||
|
||||
auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}});
|
||||
REQUIRE(!res.has_error());
|
||||
auto res_tensor = res.value()["img"].get<Tensor>();
|
||||
REQUIRE(res_tensor.device() == device);
|
||||
auto shape = res_tensor.desc().shape;
|
||||
REQUIRE(shape == std::vector<int64_t>{1, mat.channels(), mat.rows, mat.cols});
|
||||
|
||||
const Device kHost{"cpu"};
|
||||
auto host_tensor = MakeAvailableOnDevice(res_tensor, kHost, stream);
|
||||
REQUIRE(stream.Wait());
|
||||
|
||||
// mat's shape is {h, w, c}, while res_tensor's shape is {1, c, h, w}
|
||||
// compare each channel between `res_tensor` and `mat`
|
||||
auto step = shape[2] * shape[3] * mat.elemSize1();
|
||||
auto data = host_tensor.value().data<uint8_t>();
|
||||
for (auto i = 0; i < mat.channels(); ++i) {
|
||||
cv::Mat _mat{mat.rows, mat.cols, CV_MAKETYPE(mat.depth(), 1), data};
|
||||
REQUIRE(::mmdeploy::cpu::Compare(channel_mats[i], _mat));
|
||||
data += step;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_CASE("transform DefaultFormatBundle", "[img2tensor]") {
|
||||
auto gResource = MMDeployTestResources::Get();
|
||||
auto img_list = gResource.LocateImageResources("transform");
|
||||
REQUIRE(!img_list.empty());
|
||||
|
||||
auto img_path = img_list.front();
|
||||
cv::Mat bgr_mat = cv::imread(img_path, cv::IMREAD_COLOR);
|
||||
cv::Mat gray_mat = cv::imread(img_path, cv::IMREAD_GRAYSCALE);
|
||||
cv::Mat bgr_float_mat;
|
||||
cv::Mat gray_float_mat;
|
||||
bgr_mat.convertTo(bgr_float_mat, CV_32FC3);
|
||||
gray_mat.convertTo(gray_float_mat, CV_32FC1);
|
||||
|
||||
Value cfg{{"type", "DefaultFormatBundle"}, {"keys", {"img"}}};
|
||||
vector<cv::Mat> mats{bgr_mat, gray_mat, bgr_float_mat, gray_float_mat};
|
||||
for (auto& mat : mats) {
|
||||
TestDefaultFormatBundle(cfg, mat);
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user