Add DefaultFormatBundle (#208)

* keep DefaultFormatBundle

* add DefaultFormatBundle

* add condition

* resolve comments

* remove useless

* add override
This commit is contained in:
AllentDan 2022-03-16 15:52:57 +08:00 committed by GitHub
parent 776659a6ce
commit ea54f3b2fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 350 additions and 12 deletions

View File

@ -9,6 +9,7 @@ set(SRCS
collect_impl.cpp
crop_impl.cpp
image2tensor_impl.cpp
default_format_bundle_impl.cpp
load_impl.cpp
normalize_impl.cpp
pad_impl.cpp

View File

@ -0,0 +1,58 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "core/utils/device_utils.h"
#include "opencv_utils.h"
#include "preprocess/transform/default_format_bundle.h"
namespace mmdeploy {
namespace cpu {
// CPU backend of DefaultFormatBundle: both hooks round-trip the tensor through a cv::Mat
// using the helpers from opencv_utils.h.
class DefaultFormatBundleImpl : public ::mmdeploy::DefaultFormatBundleImpl {
 public:
  explicit DefaultFormatBundleImpl(const Value& args) : ::mmdeploy::DefaultFormatBundleImpl(args) {}

 protected:
  // Returns a float32 copy of `tensor` when `img_to_float` is set and the tensor's data type is
  // kINT8; in every other case the tensor (made available on this device) is returned as is.
  Result<Tensor> ToFloat32(const Tensor& tensor, const bool& img_to_float) override {
    OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

    const auto src_type = host_tensor.desc().data_type;
    const bool needs_cast = img_to_float && src_type == DataType::kINT8;
    if (!needs_cast) {
      return host_tensor;
    }

    auto mat = Tensor2CVMat(host_tensor);
    // Same channel count, 32-bit float depth.
    mat.convertTo(mat, CV_32FC(mat.channels()));
    auto float_tensor = CVMat2Tensor(mat);
    return float_tensor;
  }

  // Transposes the image via Transpose() and reshapes the result to {1, C, H, W}, where
  // H, W and C are read from dimensions 1, 2 and 3 of the input shape.
  Result<Tensor> HWC2CHW(const Tensor& tensor) override {
    OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(tensor, device_, stream_));

    auto dims = host_tensor.shape();
    const int rows = dims[1];
    const int cols = dims[2];
    const int planes = dims[3];

    auto transposed = Transpose(Tensor2CVMat(host_tensor));
    auto chw_tensor = CVMat2Tensor(transposed);
    chw_tensor.Reshape({1, planes, rows, cols});
    return chw_tensor;
  }
};
// Creator that builds the CPU DefaultFormatBundleImpl; it reports the name "cpu" and version 1.
class DefaultFormatBundleImplCreator : public Creator<::mmdeploy::DefaultFormatBundleImpl> {
 public:
  const char* GetName() const override {
    return "cpu";
  }

  int GetVersion() const override {
    return 1;
  }

  // Instantiates the CPU implementation from the transform configuration.
  ReturnType Create(const Value& cfg) override {
    return std::make_unique<DefaultFormatBundleImpl>(cfg);
  }
};
} // namespace cpu
} // namespace mmdeploy
// Bring the base impl type and the CPU creator into scope and hand them to REGISTER_MODULE.
// NOTE(review): the macro presumably adds the creator to Registry<DefaultFormatBundleImpl> so it
// can be found under the name "cpu" — confirm against the macro definition.
using mmdeploy::DefaultFormatBundleImpl;
using mmdeploy::cpu::DefaultFormatBundleImplCreator;
REGISTER_MODULE(DefaultFormatBundleImpl, DefaultFormatBundleImplCreator);

View File

@ -14,6 +14,7 @@ include(${CMAKE_SOURCE_DIR}/cmake/MMDeploy.cmake)
set(SRCS
crop_impl.cpp
image2tensor_impl.cpp
default_format_bundle_impl.cpp
load_impl.cpp
normalize_impl.cpp
pad_impl.cpp

View File

@ -0,0 +1,87 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <cuda_runtime.h>
#include "core/utils/device_utils.h"
#include "preprocess/transform/default_format_bundle.h"
namespace mmdeploy {
namespace cuda {
// Forward declarations of the helpers launched by DefaultFormatBundleImpl below. Their
// definitions are not in this file (NOTE(review): presumably a CUDA source in the same target —
// confirm).
//
// CastToFloat<channels>: called by ToFloat32 with a uint8 source buffer and a float destination
// buffer for a `height` x `width` image; the work is submitted on `stream`.
template <int channels>
void CastToFloat(const uint8_t* src, int height, int width, float* dst, cudaStream_t stream);
// Transpose<T>: called by HWC2CHW with source and destination buffers of the same element type T
// and the image's height, width and channel count; the work is submitted on `stream`.
template <typename T>
void Transpose(const T* src, int height, int width, int channels, T* dst, cudaStream_t stream);
// CUDA backend of DefaultFormatBundle. Both hooks allocate the destination tensor on `device_`
// and launch the corresponding helper (CastToFloat / Transpose) on the native CUDA stream.
class DefaultFormatBundleImpl final : public ::mmdeploy::DefaultFormatBundleImpl {
 public:
  explicit DefaultFormatBundleImpl(const Value& args) : ::mmdeploy::DefaultFormatBundleImpl(args) {}

 protected:
  // Converts a kINT8 tensor (accessed as uint8_t data) to float32 when `img_to_float` is set.
  // Any other combination returns the tensor unchanged (after making it available on this
  // device). Only 1- and 3-channel images are supported; other channel counts yield
  // eNotSupported.
  Result<Tensor> ToFloat32(const Tensor& tensor, const bool& img_to_float) override {
    OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
    const auto data_type = src_tensor.data_type();
    if (!(img_to_float && data_type == DataType::kINT8)) {
      return src_tensor;
    }

    // The shape is only needed on the conversion path, so it is read here rather than up front.
    const auto h = static_cast<int>(tensor.shape(1));
    const auto w = static_cast<int>(tensor.shape(2));
    const auto c = tensor.shape(3);
    auto stream = ::mmdeploy::GetNative<cudaStream_t>(stream_);

    TensorDesc desc{device_, DataType::kFLOAT, tensor.shape(), ""};
    Tensor dst_tensor{desc};
    if (c == 3) {
      CastToFloat<3>(src_tensor.data<uint8_t>(), h, w, dst_tensor.data<float>(), stream);
    } else if (c == 1) {
      CastToFloat<1>(src_tensor.data<uint8_t>(), h, w, dst_tensor.data<float>(), stream);
    } else {
      MMDEPLOY_ERROR("channel num: unsupported channel num {}", c);
      return Status(eNotSupported);
    }
    return dst_tensor;
  }

  // Produces a {1, C, H, W} tensor from `tensor`, where H, W and C are dimensions 1, 2 and 3 of
  // the input shape. kINT8 (as uint8_t) and kFLOAT data are supported; any other data type
  // yields eNotSupported.
  Result<Tensor> HWC2CHW(const Tensor& tensor) override {
    OUTCOME_TRY(auto src_tensor, MakeAvailableOnDevice(tensor, device_, stream_));
    const auto h = tensor.shape(1);
    const auto w = tensor.shape(2);
    const auto c = tensor.shape(3);

    Tensor dst_tensor(src_tensor.desc());
    dst_tensor.Reshape({1, c, h, w});

    auto stream = ::mmdeploy::GetNative<cudaStream_t>(stream_);

    if (DataType::kINT8 == tensor.data_type()) {
      auto input = src_tensor.data<uint8_t>();
      auto output = dst_tensor.data<uint8_t>();
      Transpose(input, static_cast<int>(h), static_cast<int>(w), static_cast<int>(c), output,
                stream);
    } else if (DataType::kFLOAT == tensor.data_type()) {
      auto input = src_tensor.data<float>();
      auto output = dst_tensor.data<float>();
      Transpose(input, static_cast<int>(h), static_cast<int>(w), static_cast<int>(c), output,
                stream);
    } else {
      // An assert would vanish under NDEBUG and let an unfilled tensor be returned as success;
      // report the failure through the Result instead, like ToFloat32 does.
      MMDEPLOY_ERROR("data type: unsupported data type {}",
                     static_cast<int>(tensor.data_type()));
      return Status(eNotSupported);
    }
    return dst_tensor;
  }
};
// Creator that builds the CUDA DefaultFormatBundleImpl; it reports the name "cuda" and version 1.
class DefaultFormatBundleImplCreator : public Creator<::mmdeploy::DefaultFormatBundleImpl> {
 public:
  const char* GetName() const override {
    return "cuda";
  }

  int GetVersion() const override {
    return 1;
  }

  // Instantiates the CUDA implementation from the transform configuration.
  ReturnType Create(const Value& args) override {
    return std::make_unique<DefaultFormatBundleImpl>(args);
  }
};
} // namespace cuda
} // namespace mmdeploy
// Bring the base impl type and the CUDA creator into scope and hand them to REGISTER_MODULE.
// NOTE(review): the macro presumably adds the creator to Registry<DefaultFormatBundleImpl> so it
// can be found under the name "cuda" — confirm against the macro definition.
using ::mmdeploy::DefaultFormatBundleImpl;
using ::mmdeploy::cuda::DefaultFormatBundleImplCreator;
REGISTER_MODULE(DefaultFormatBundleImpl, DefaultFormatBundleImplCreator);

View File

@ -9,6 +9,7 @@ set(SRCS
compose.cpp
crop.cpp
image2tensor.cpp
default_format_bundle.cpp
load.cpp
normalize.cpp
pad.cpp

View File

@ -9,7 +9,7 @@ namespace mmdeploy {
class MMDEPLOY_API CollectImpl : public Module {
public:
explicit CollectImpl(const Value& args);
~CollectImpl() = default;
~CollectImpl() override = default;
Result<Value> Process(const Value& input) override;
@ -27,7 +27,7 @@ class MMDEPLOY_API CollectImpl : public Module {
class MMDEPLOY_API Collect : public Transform {
public:
explicit Collect(const Value& args, int version = 0);
~Collect() = default;
~Collect() override = default;
Result<Value> Process(const Value& input) override;

View File

@ -13,7 +13,7 @@ namespace mmdeploy {
class MMDEPLOY_API CenterCropImpl : public TransformImpl {
public:
explicit CenterCropImpl(const Value& args);
~CenterCropImpl() = default;
~CenterCropImpl() override = default;
Result<Value> Process(const Value& input) override;
@ -34,7 +34,7 @@ class MMDEPLOY_API CenterCropImpl : public TransformImpl {
class MMDEPLOY_API CenterCrop : public Transform {
public:
explicit CenterCrop(const Value& args, int version = 0);
~CenterCrop() = default;
~CenterCrop() override = default;
Result<Value> Process(const Value& input) override { return impl_->Process(input); }

View File

@ -0,0 +1,75 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "default_format_bundle.h"
#include <cassert>
#include "archive/json_archive.h"
#include "core/tensor.h"
namespace mmdeploy {
// Reads the optional boolean "img_to_float" from `args`. A missing key or a non-boolean value
// leaves the in-class default of `arg_.img_to_float` untouched.
DefaultFormatBundleImpl::DefaultFormatBundleImpl(const Value& args) : TransformImpl(args) {
  const bool has_bool_flag = args.contains("img_to_float") && args["img_to_float"].is_boolean();
  if (has_bool_flag) {
    arg_.img_to_float = args["img_to_float"].get<bool>();
  }
}
// Formats the "img" entry of `input` and returns the updated copy of `input`:
//   1. ToFloat32() is applied according to `arg_.img_to_float`;
//   2. "pad_shape", "scale_factor" and "img_norm_cfg" are filled with defaults when absent;
//   3. HWC2CHW() is applied and its result stored back under "img".
// Inputs without an "img" key are returned unchanged. Errors from the two hooks are propagated.
Result<Value> DefaultFormatBundleImpl::Process(const Value& input) {
  MMDEPLOY_DEBUG("DefaultFormatBundle input: {}", to_json(input).dump(2));

  Value output = input;
  if (input.contains("img")) {
    Tensor raw_img = input["img"].get<Tensor>();
    OUTCOME_TRY(output["img"], ToFloat32(raw_img, arg_.img_to_float));
    Tensor img = output["img"].get<Tensor>();

    // Default meta keys, derived from the (possibly converted) image.
    if (!output.contains("pad_shape")) {
      // pad_shape defaults to the full tensor shape.
      for (auto dim : img.shape()) {
        output["pad_shape"].push_back(dim);
      }
    }

    if (!output.contains("scale_factor")) {
      output["scale_factor"].push_back(1.0);
    }

    if (!output.contains("img_norm_cfg")) {
      // One mean/std pair per channel; the channel count is dimension 3 of the shape.
      int num_channels = img.shape()[3];
      for (int ch = 0; ch < num_channels; ++ch) {
        output["img_norm_cfg"]["mean"].push_back(0.0);
        output["img_norm_cfg"]["std"].push_back(1.0);
      }
      output["img_norm_cfg"]["to_rgb"] = false;
    }

    // Finally transpose the image.
    OUTCOME_TRY(output["img"], HWC2CHW(img));
  }

  MMDEPLOY_DEBUG("DefaultFormatBundle output: {}", to_json(output).dump(2));
  return output;
}
// Looks up the implementation creator for `specified_platform_` / `version` and uses it to build
// `impl_`. Throws std::domain_error (after logging) when no creator is registered.
DefaultFormatBundle::DefaultFormatBundle(const Value& args, int version) : Transform(args) {
  auto creator = Registry<DefaultFormatBundleImpl>::Get().GetCreator(specified_platform_, version);
  if (creator != nullptr) {
    impl_ = creator->Create(args);
    return;
  }

  MMDEPLOY_ERROR("'DefaultFormatBundle' is not supported on '{}' platform", specified_platform_);
  throw std::domain_error("'DefaultFormatBundle' is not supported on specified platform");
}
// Creator for the DefaultFormatBundle transform; it reports the name "DefaultFormatBundle" and
// passes its own version on to the transform it builds.
class DefaultFormatBundleCreator : public Creator<Transform> {
 public:
  const char* GetName() const override {
    return "DefaultFormatBundle";
  }

  int GetVersion() const override {
    return version_;
  }

  ReturnType Create(const Value& cfg) override {
    return std::make_unique<DefaultFormatBundle>(cfg, version_);
  }

 private:
  int version_ = 1;
};
// Hand the transform creator to REGISTER_MODULE for the Transform module type, and emit the
// definition matching MMDEPLOY_DECLARE_REGISTRY(DefaultFormatBundleImpl) from the header.
// NOTE(review): exact registration semantics live in the macro definitions — confirm there.
REGISTER_MODULE(Transform, DefaultFormatBundleCreator);
MMDEPLOY_DEFINE_REGISTRY(DefaultFormatBundleImpl);
} // namespace mmdeploy

View File

@ -0,0 +1,49 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_DEFAULT_FORMAT_BUNDLE_H
#define MMDEPLOY_DEFAULT_FORMAT_BUNDLE_H
#include "core/tensor.h"
#include "transform.h"
namespace mmdeploy {
/**
 * @brief Platform-independent part of the DefaultFormatBundle transform.
 *
 * It simplifies the pipeline of formatting common fields: Process() converts the "img" tensor
 * with ToFloat32(), fills in default meta keys that are missing ("pad_shape", "scale_factor",
 * "img_norm_cfg") and finally applies HWC2CHW(). The two conversions are supplied by the
 * platform-specific subclasses.
 */
class MMDEPLOY_API DefaultFormatBundleImpl : public TransformImpl {
 public:
  /**
   * @param args transform configuration; an optional boolean "img_to_float" overrides the
   *        default stored in `arg_`.
   */
  explicit DefaultFormatBundleImpl(const Value& args);
  ~DefaultFormatBundleImpl() override = default;

  /**
   * @param input value that may contain an "img" tensor
   * @return a copy of `input` with the formatted "img" and default meta keys, or the error
   *         reported by one of the conversion hooks
   */
  Result<Value> Process(const Value& input) override;

 protected:
  /// Returns `tensor` as float32 when `img_to_float` requests it; implemented per platform.
  virtual Result<Tensor> ToFloat32(const Tensor& tensor, const bool& img_to_float) = 0;

  /// Returns the transposed (channel-first) version of `tensor`; implemented per platform.
  virtual Result<Tensor> HWC2CHW(const Tensor& tensor) = 0;

 protected:
  /// Arguments parsed from the transform configuration.
  struct default_format_bundle_arg_t {
    bool img_to_float = true;  // passed to ToFloat32() by Process()
  };
  using ArgType = struct default_format_bundle_arg_t;

 protected:
  ArgType arg_;
};
// Transform front end for DefaultFormatBundle. The constructor obtains a platform-specific
// DefaultFormatBundleImpl and Process() forwards to it.
class MMDEPLOY_API DefaultFormatBundle : public Transform {
public:
// `args` is the transform configuration; `version` selects the implementation creator.
explicit DefaultFormatBundle(const Value& args, int version = 0);
~DefaultFormatBundle() override = default;
// Delegates to the implementation chosen at construction time.
Result<Value> Process(const Value& input) override { return impl_->Process(input); }
private:
// Platform-specific implementation, set by the constructor.
std::unique_ptr<DefaultFormatBundleImpl> impl_;
};
// Declares the registry for DefaultFormatBundleImpl; the matching MMDEPLOY_DEFINE_REGISTRY is in
// default_format_bundle.cpp.
MMDEPLOY_DECLARE_REGISTRY(DefaultFormatBundleImpl);
} // namespace mmdeploy
#endif // MMDEPLOY_DEFAULT_FORMAT_BUNDLE_H

View File

@ -17,7 +17,7 @@ namespace mmdeploy {
class MMDEPLOY_API ImageToTensorImpl : public TransformImpl {
public:
ImageToTensorImpl(const Value& args);
~ImageToTensorImpl() = default;
~ImageToTensorImpl() override = default;
Result<Value> Process(const Value& input) override;
@ -37,7 +37,7 @@ class MMDEPLOY_API ImageToTensorImpl : public TransformImpl {
class MMDEPLOY_API ImageToTensor : public Transform {
public:
explicit ImageToTensor(const Value& args, int version = 0);
~ImageToTensor() = default;
~ImageToTensor() override = default;
Result<Value> Process(const Value& input) override { return impl_->Process(input); }

View File

@ -11,7 +11,7 @@ namespace mmdeploy {
class MMDEPLOY_API PrepareImageImpl : public TransformImpl {
public:
explicit PrepareImageImpl(const Value& args);
~PrepareImageImpl() = default;
~PrepareImageImpl() override = default;
Result<Value> Process(const Value& input) override;
@ -32,7 +32,7 @@ class MMDEPLOY_API PrepareImageImpl : public TransformImpl {
class MMDEPLOY_API PrepareImage : public Transform {
public:
explicit PrepareImage(const Value& args, int version = 0);
~PrepareImage() = default;
~PrepareImage() override = default;
Result<Value> Process(const Value& input) override { return impl_->Process(input); }

View File

@ -11,7 +11,7 @@ namespace mmdeploy {
class MMDEPLOY_API NormalizeImpl : public TransformImpl {
public:
explicit NormalizeImpl(const Value& args);
~NormalizeImpl() = default;
~NormalizeImpl() override = default;
Result<Value> Process(const Value& input) override;
@ -31,7 +31,7 @@ class MMDEPLOY_API NormalizeImpl : public TransformImpl {
class MMDEPLOY_API Normalize : public Transform {
public:
explicit Normalize(const Value& args, int version = 0);
~Normalize() = default;
~Normalize() override = default;
Result<Value> Process(const Value& input) override { return impl_->Process(input); }

View File

@ -191,8 +191,6 @@ def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config):
and 'RescaleToZeroOne' not in item['type']
]
for i, transform in enumerate(transforms):
if transform['type'] == 'DefaultFormatBundle':
transforms[i] = dict(type='ImageToTensor', keys=['img'])
if 'keys' in transform and transform['keys'] == ['lq']:
transform['keys'] = ['img']
if 'key' in transform and transform['key'] == 'lq':

View File

@ -0,0 +1,68 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "catch.hpp"
#include "core/tensor.h"
#include "core/utils/device_utils.h"
#include "opencv_utils.h"
#include "preprocess/transform/transform.h"
#include "test_resource.h"
#include "test_utils.h"
using namespace mmdeploy;
using namespace mmdeploy::test;
using namespace std;
// Runs the transform described by `cfg` on `mat` for every device listed by the test resources
// and checks that the resulting "img" tensor
//   * lives on that device,
//   * has shape {1, channels, rows, cols}, and
//   * holds, plane by plane, the channels of `mat`.
void TestDefaultFormatBundle(const Value& cfg, const cv::Mat& mat) {
  auto gResource = MMDeployTestResources::Get();
  for (auto const& device_name : gResource.device_names()) {
    Device device{device_name.c_str()};
    Stream stream{device};
    auto transform = CreateTransform(cfg, device, stream);
    REQUIRE(transform != nullptr);

    auto res = transform->Process({{"img", cpu::CVMat2Tensor(mat)}});
    REQUIRE(!res.has_error());
    auto res_tensor = res.value()["img"].get<Tensor>();
    REQUIRE(res_tensor.device() == device);
    auto shape = res_tensor.desc().shape;
    REQUIRE(shape == std::vector<int64_t>{1, mat.channels(), mat.rows, mat.cols});

    // The transform converts uint8 images to float32 unless `img_to_float` is disabled, so the
    // reference planes must use the element type of the *output*, not of the input mat.
    const int expected_depth = res_tensor.data_type() == DataType::kFLOAT ? CV_32F : mat.depth();
    vector<cv::Mat> channel_mats(mat.channels());
    for (auto i = 0; i < mat.channels(); ++i) {
      cv::extractChannel(mat, channel_mats[i], i);
      if (channel_mats[i].depth() != expected_depth) {
        channel_mats[i].convertTo(channel_mats[i], expected_depth);
      }
    }

    const Device kHost{"cpu"};
    auto host_tensor = MakeAvailableOnDevice(res_tensor, kHost, stream);
    REQUIRE(!host_tensor.has_error());
    REQUIRE(stream.Wait());

    // mat's shape is {h, w, c}, while res_tensor's shape is {1, c, h, w}
    // compare each channel between `res_tensor` and `mat`
    // `step` is the size of one output plane in bytes.
    auto step = shape[2] * shape[3] * CV_ELEM_SIZE1(expected_depth);
    auto data = host_tensor.value().data<uint8_t>();
    for (auto i = 0; i < mat.channels(); ++i) {
      cv::Mat _mat{mat.rows, mat.cols, CV_MAKETYPE(expected_depth, 1), data};
      REQUIRE(::mmdeploy::cpu::Compare(channel_mats[i], _mat));
      data += step;
    }
  }
}
// Feeds the first "transform" test image through DefaultFormatBundle as colour/grayscale and as
// uint8/float32 inputs.
TEST_CASE("transform DefaultFormatBundle", "[img2tensor]") {
  auto gResource = MMDeployTestResources::Get();
  auto images = gResource.LocateImageResources("transform");
  REQUIRE(!images.empty());
  auto image_path = images.front();

  // uint8 inputs straight from disk.
  cv::Mat color_u8 = cv::imread(image_path, cv::IMREAD_COLOR);
  cv::Mat gray_u8 = cv::imread(image_path, cv::IMREAD_GRAYSCALE);

  // float32 variants of the same images.
  cv::Mat color_f32;
  cv::Mat gray_f32;
  color_u8.convertTo(color_f32, CV_32FC3);
  gray_u8.convertTo(gray_f32, CV_32FC1);

  Value cfg{{"type", "DefaultFormatBundle"}, {"keys", {"img"}}};

  vector<cv::Mat> inputs;
  inputs.push_back(color_u8);
  inputs.push_back(gray_u8);
  inputs.push_back(color_f32);
  inputs.push_back(gray_f32);
  for (auto& input_mat : inputs) {
    TestDefaultFormatBundle(cfg, input_mat);
  }
}