[Feature] Add option to fuse transform. (#741)
* add collect_impl.cpp to cuda device * add dummy compute node wich device elena * add compiler & dynamic library loader * add code to compile with gen code(elena) * move folder * fix lint * add tracer module * add license * update type id * add fuse kernel registry * remove compilier & dynamic_library * update fuse kernel interface * Add elena-mmdeploy project in 3rd-party * Fix README.md * fix cmake file * Support cuda device and clang format all file * Add cudaStreamSynchronize for cudafree * fix cudaStreamSynchronize * rename to __tracer__ * remove unused code * update kernel * update extract elena script * update gitignore * fix ci * Change the crop_size to crop_h and crop_w in arglist * update Tracer * remove cond * avoid allocate memory * add build.sh for elena * remove code * update test * Support bilinear resize with float input * Rename elena-mmdeploy to delete * Introduce public submodule * use get_ref * update elena * update tools * update tools * update fuse transform docs * add fuse transform doc link to get_started * fix shape in crop * remove fuse_transform_ == true check * remove fuse_transform_ member * remove elena_int.h * doesn't dump transform_static.json * update tracer * update CVFusion to remove compile warning * remove mmcv version > 1.5.1 dep * fix tests * update docs * add elena use option * remove submodule of CVFusion * update doc * use auto * use throw_exception(eEntryNotFound); * update Co-authored-by: cx <cx@ubuntu20.04> Co-authored-by: miraclezqc <969226879@qq.com>pull/865/head
parent
ac3a12026d
commit
6b01a2e649
csrc/mmdeploy/preprocess
docs
en
02-how-to-run
zh_cn
02-how-to-run
mmdeploy/backend/sdk
tests/test_csrc/preprocess
tools/elena
|
@ -159,3 +159,8 @@ fusion_result.json
|
|||
# snpe
|
||||
grpc-cpp-plugin
|
||||
service/snpe/grpc_cpp_plugin
|
||||
|
||||
# elena-code
|
||||
csrc/mmdeploy/preprocess/elena/json
|
||||
csrc/mmdeploy/preprocess/elena/cpu_kernel/*
|
||||
csrc/mmdeploy/preprocess/elena/cuda_kernel/*
|
||||
|
|
|
@ -34,6 +34,7 @@ option(MMDEPLOY_BUILD_EXAMPLES "build examples" OFF)
|
|||
option(MMDEPLOY_SPDLOG_EXTERNAL "use external spdlog" OFF)
|
||||
option(MMDEPLOY_ZIP_MODEL "support SDK model in zip format" OFF)
|
||||
option(MMDEPLOY_COVERAGE "build SDK for coverage" OFF)
|
||||
option(MMDEPLOY_ELENA_FUSION "use elena to fuse preprocess" OFF)
|
||||
|
||||
set(MMDEPLOY_TARGET_DEVICES "cpu" CACHE STRING "target devices to support")
|
||||
set(MMDEPLOY_TARGET_BACKENDS "" CACHE STRING "target inference engines to support")
|
||||
|
|
|
@ -4,6 +4,9 @@ project(mmdeploy_transform_module)
|
|||
|
||||
add_subdirectory(transform)
|
||||
add_subdirectory(cpu)
|
||||
if (MMDEPLOY_ELENA_FUSION)
|
||||
add_subdirectory(elena)
|
||||
endif ()
|
||||
if ("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
|
||||
add_subdirectory(cuda)
|
||||
endif ()
|
||||
|
|
|
@ -5,6 +5,7 @@ project(mmdeploy_cuda_transform_impl CUDA CXX)
|
|||
find_package(pplcv REQUIRED)
|
||||
|
||||
set(SRCS
|
||||
collect_impl.cpp
|
||||
crop_impl.cpp
|
||||
image2tensor_impl.cpp
|
||||
default_format_bundle_impl.cpp
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/preprocess/transform/collect.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace cuda {
|
||||
|
||||
class CollectImpl : public ::mmdeploy::CollectImpl {
|
||||
public:
|
||||
CollectImpl(const Value& args) : ::mmdeploy::CollectImpl(args) {}
|
||||
~CollectImpl() = default;
|
||||
};
|
||||
|
||||
/// Factory registered under the name "cuda" (version 1) that builds the CUDA
/// collect implementation.
class CollectImplCreator : public Creator<::mmdeploy::CollectImpl> {
 public:
  auto GetName() const -> const char* override { return "cuda"; }

  auto GetVersion() const -> int override { return 1; }

  auto Create(const Value& args) -> std::unique_ptr<::mmdeploy::CollectImpl> override {
    auto impl = std::make_unique<CollectImpl>(args);
    return impl;
  }
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace mmdeploy
|
||||
|
||||
using mmdeploy::CollectImpl;
|
||||
using mmdeploy::cuda::CollectImplCreator;
|
||||
REGISTER_MODULE(CollectImpl, CollectImplCreator);
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
project(mmdeploy_elena_transform_impl)

# Hand-written sources of the elena transform implementations.
set(SRCS
    crop_impl.cpp
    collect_impl.cpp
    image2tensor_impl.cpp
    default_format_bundle_impl.cpp
    load_impl.cpp
    normalize_impl.cpp
    pad_impl.cpp
    resize_impl.cpp
    elena_registry.cpp)

# Kernel sources found in cpu_kernel/ (and cuda_kernel/ when cuda is a target
# device) are appended after the hand-written sources.
file(GLOB CPU_KERNEL_SRCS "cpu_kernel/*.cpp")

set(ALL_SRCS ${SRCS})
list(APPEND ALL_SRCS ${CPU_KERNEL_SRCS})
if ("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
    file(GLOB CUDA_KERNEL_SRCS "cuda_kernel/*.cu")
    list(APPEND ALL_SRCS ${CUDA_KERNEL_SRCS})
endif ()

mmdeploy_add_module(${PROJECT_NAME} "${ALL_SRCS}")
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(${PROJECT_NAME} PRIVATE mmdeploy::transform)
if ("cuda" IN_LIST MMDEPLOY_TARGET_DEVICES)
    target_link_libraries(${PROJECT_NAME} PRIVATE cuda)
endif ()
add_library(mmdeploy::transform_impl::elena ALIAS ${PROJECT_NAME})
|
|
@ -0,0 +1,145 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
|
||||
#include "elena_registry.h"
|
||||
#include "mmdeploy/archive/json_archive.h"
|
||||
#include "mmdeploy/core/mat.h"
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
#include "mmdeploy/core/utils/device_utils.h"
|
||||
#include "mmdeploy/core/utils/formatter.h"
|
||||
#include "mmdeploy/preprocess/transform/collect.h"
|
||||
#include "mmdeploy/preprocess/transform/tracer.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace elena {
|
||||
|
||||
using namespace trace;
|
||||
|
||||
struct ExtractTransParamVisitor {
|
||||
bool valid{true};
|
||||
std::set<std::string> st;
|
||||
|
||||
std::array<float, 3> mean;
|
||||
std::array<float, 3> std;
|
||||
std::array<int, 2> resize_hw;
|
||||
std::string resize_mode;
|
||||
float pad_val;
|
||||
std::array<int, 4> pad_tlbr;
|
||||
std::array<int, 2> pad_hw;
|
||||
std::array<int, 4> crop_tlbr;
|
||||
std::array<int, 2> crop_hw;
|
||||
|
||||
void CheckValid(const std::string& name) {
|
||||
if (st.count(name)) {
|
||||
valid = false;
|
||||
return;
|
||||
}
|
||||
st.insert(name);
|
||||
}
|
||||
|
||||
void operator()(CvtColorParam&) {}
|
||||
void operator()(CastParam&) {}
|
||||
void operator()(HWC2CHWParam&) {}
|
||||
|
||||
void operator()(ResizeParam& param) {
|
||||
CheckValid("Resize");
|
||||
resize_hw = {param.size[0], param.size[1]};
|
||||
resize_mode = param.mode;
|
||||
}
|
||||
void operator()(PadParam& param) {
|
||||
CheckValid("Pad");
|
||||
pad_val = param.pad_val;
|
||||
std::copy_n(param.tlbr.begin(), 4, pad_tlbr.begin());
|
||||
std::copy_n(param.size.begin(), 2, pad_hw.begin());
|
||||
}
|
||||
void operator()(NormParam& param) {
|
||||
CheckValid("Normalize");
|
||||
std::copy(param.mean.begin(), param.mean.end(), mean.begin());
|
||||
std::copy(param.std.begin(), param.std.end(), std.begin());
|
||||
}
|
||||
void operator()(CropParam& param) {
|
||||
CheckValid("CenterCrop");
|
||||
std::copy_n(param.tlbr.begin(), 4, crop_tlbr.begin());
|
||||
std::copy_n(param.size.begin(), 2, crop_hw.begin());
|
||||
}
|
||||
};
|
||||
|
||||
class CollectImpl : public ::mmdeploy::CollectImpl {
|
||||
public:
|
||||
CollectImpl(const Value& args) : ::mmdeploy::CollectImpl(args) {
|
||||
Platform platform(device_.platform_id());
|
||||
device_name_ = platform.GetPlatformName();
|
||||
sha256_ = args["context"].value("sha256", std::string(""));
|
||||
}
|
||||
|
||||
~CollectImpl() = default;
|
||||
|
||||
Result<Value> Process(const Value& input) override {
|
||||
auto tracer = input["__tracer__"].get<Tracer>();
|
||||
Mat _src_mat = input["ori_img"].get<Mat>();
|
||||
OUTCOME_TRY(auto src_mat, MakeAvailableOnDevice(_src_mat, device_, stream_));
|
||||
OUTCOME_TRY(stream_.Wait());
|
||||
|
||||
ExtractTransParamVisitor visitor{};
|
||||
for (auto&& trans : tracer.trans_) {
|
||||
std::visit(visitor, trans);
|
||||
}
|
||||
std::string tag = sha256_ + "_" + device_name_;
|
||||
FuseFunc func = FuseKernel::Get().GetFunc(tag);
|
||||
|
||||
if (!visitor.valid) {
|
||||
MMDEPLOY_ERROR("unsupported fuse transform");
|
||||
throw std::invalid_argument("");
|
||||
}
|
||||
if (src_mat.type() != DataType::kINT8) {
|
||||
MMDEPLOY_ERROR("unsupported data type in fuse transform");
|
||||
throw std::invalid_argument("");
|
||||
}
|
||||
if (!func) {
|
||||
MMDEPLOY_ERROR("can't find fuse function with tag: {}", tag);
|
||||
throw std::invalid_argument("");
|
||||
}
|
||||
|
||||
Value output = input;
|
||||
auto img_fields = GetImageFields(input);
|
||||
for (auto& key : img_fields) {
|
||||
assert(input.contains(key));
|
||||
auto src_tensor = input[key].get<Tensor>();
|
||||
auto desc = src_tensor.desc();
|
||||
desc.device = device_;
|
||||
Tensor dst_tensor{desc};
|
||||
|
||||
func(stream_.GetNative(), src_mat.data<uint8_t>(), src_mat.height(), src_mat.width(),
|
||||
to_string(src_mat.pixel_format()).c_str(), visitor.resize_hw[0], visitor.resize_hw[1],
|
||||
visitor.resize_mode.c_str(), visitor.crop_tlbr[0], visitor.crop_tlbr[1],
|
||||
visitor.crop_hw[0], visitor.crop_hw[1], visitor.mean[0], visitor.mean[1],
|
||||
visitor.mean[2], visitor.std[0], visitor.std[1], visitor.std[2], visitor.pad_tlbr[0],
|
||||
visitor.pad_tlbr[1], visitor.pad_tlbr[2], visitor.pad_tlbr[3], visitor.pad_hw[0],
|
||||
visitor.pad_hw[1], visitor.pad_val, dst_tensor.data<float>(), dst_tensor.shape(2),
|
||||
dst_tensor.shape(3));
|
||||
output[key] = std::move(dst_tensor);
|
||||
}
|
||||
return ::mmdeploy::CollectImpl::Process(output);
|
||||
}
|
||||
|
||||
std::string sha256_;
|
||||
std::string device_name_;
|
||||
};
|
||||
|
||||
/// Factory registered under the name "elena" (version 1) that builds the
/// fused-kernel collect implementation.
class CollectImplCreator : public Creator<::mmdeploy::CollectImpl> {
 public:
  auto GetName() const -> const char* override { return "elena"; }

  auto GetVersion() const -> int override { return 1; }

  auto Create(const Value& args) -> std::unique_ptr<::mmdeploy::CollectImpl> override {
    auto impl = std::make_unique<CollectImpl>(args);
    return impl;
  }
};
|
||||
|
||||
} // namespace elena
|
||||
} // namespace mmdeploy
|
||||
|
||||
using mmdeploy::CollectImpl;
|
||||
using mmdeploy::elena::CollectImplCreator;
|
||||
REGISTER_MODULE(CollectImpl, CollectImplCreator);
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/preprocess/transform/crop.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace elena {
|
||||
|
||||
class CenterCropImpl : public ::mmdeploy::CenterCropImpl {
|
||||
public:
|
||||
explicit CenterCropImpl(const Value& args) : ::mmdeploy::CenterCropImpl(args) {}
|
||||
|
||||
protected:
|
||||
Result<Tensor> CropImage(const Tensor& tensor, int top, int left, int bottom,
|
||||
int right) override {
|
||||
auto& src_desc = tensor.desc();
|
||||
auto data_type = src_desc.data_type;
|
||||
auto shape = src_desc.shape;
|
||||
shape[1] = bottom - top + 1; // h
|
||||
shape[2] = right - left + 1; // w
|
||||
|
||||
TensorDesc dummy_desc = {Device{"cpu"}, data_type, shape};
|
||||
Tensor dummy(dummy_desc, dummy_buffer_);
|
||||
|
||||
return dummy;
|
||||
}
|
||||
Buffer dummy_buffer_{Device{"cpu"}, 0, nullptr};
|
||||
};
|
||||
|
||||
/// Factory registered under the name "elena" (version 1) for the placeholder
/// center-crop step.
class CenterCropImplCreator : public Creator<::mmdeploy::CenterCropImpl> {
 public:
  auto GetName() const -> const char* override { return "elena"; }

  auto GetVersion() const -> int override { return 1; }

  auto Create(const Value& args) -> ReturnType override {
    return std::make_unique<CenterCropImpl>(args);
  }
};
|
||||
|
||||
} // namespace elena
|
||||
} // namespace mmdeploy
|
||||
|
||||
using ::mmdeploy::CenterCropImpl;
|
||||
using ::mmdeploy::elena::CenterCropImplCreator;
|
||||
|
||||
REGISTER_MODULE(CenterCropImpl, CenterCropImplCreator);
|
|
@ -0,0 +1,56 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/preprocess/transform/default_format_bundle.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace elena {
|
||||
|
||||
class DefaultFormatBundleImpl : public ::mmdeploy::DefaultFormatBundleImpl {
|
||||
public:
|
||||
explicit DefaultFormatBundleImpl(const Value& args) : ::mmdeploy::DefaultFormatBundleImpl(args) {}
|
||||
|
||||
protected:
|
||||
Result<Tensor> ToFloat32(const Tensor& tensor, const bool& img_to_float) override {
|
||||
auto& src_desc = tensor.desc();
|
||||
auto data_type = src_desc.data_type;
|
||||
auto shape = src_desc.shape;
|
||||
|
||||
if (img_to_float && data_type == DataType::kINT8) {
|
||||
data_type = DataType::kFLOAT;
|
||||
}
|
||||
|
||||
TensorDesc dummy_desc = {Device{"cpu"}, data_type, shape};
|
||||
Tensor dummy(dummy_desc, dummy_buffer_);
|
||||
|
||||
return dummy;
|
||||
}
|
||||
|
||||
Result<Tensor> HWC2CHW(const Tensor& tensor) override {
|
||||
auto& src_desc = tensor.desc();
|
||||
auto data_type = src_desc.data_type;
|
||||
auto shape = src_desc.shape;
|
||||
shape = {shape[0], shape[3], shape[1], shape[2]};
|
||||
|
||||
TensorDesc dummy_desc = {Device{"cpu"}, data_type, shape};
|
||||
Tensor dummy(dummy_desc, dummy_buffer_);
|
||||
|
||||
return dummy;
|
||||
}
|
||||
Buffer dummy_buffer_{Device{"cpu"}, 0, nullptr};
|
||||
};
|
||||
|
||||
/// Factory registered under the name "elena" (version 1) for the placeholder
/// format-bundle step.
class DefaultFormatBundleImplCreator : public Creator<::mmdeploy::DefaultFormatBundleImpl> {
 public:
  auto GetName() const -> const char* override { return "elena"; }

  auto GetVersion() const -> int override { return 1; }

  auto Create(const Value& args) -> ReturnType override {
    auto impl = std::make_unique<DefaultFormatBundleImpl>(args);
    return impl;
  }
};
|
||||
|
||||
} // namespace elena
|
||||
} // namespace mmdeploy
|
||||
|
||||
using mmdeploy::DefaultFormatBundleImpl;
|
||||
using mmdeploy::elena::DefaultFormatBundleImplCreator;
|
||||
REGISTER_MODULE(DefaultFormatBundleImpl, DefaultFormatBundleImplCreator);
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "elena_registry.h"
|
||||
|
||||
#include "mmdeploy/core/logger.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace elena {
|
||||
|
||||
// Returns the single registry instance, a function-local static that is
// created on the first call.
FuseKernel& FuseKernel::Get() {
  static FuseKernel instance;
  return instance;
}
|
||||
|
||||
// Looks up the kernel registered under `name`.
// Returns nullptr when nothing is registered under that name; the registry is
// never modified by a lookup.
FuseFunc FuseKernel::GetFunc(const std::string& name) {
  // One find() instead of count() followed by operator[]: a single tree walk.
  const auto it = entries_.find(name);
  return it != entries_.end() ? it->second : nullptr;
}
|
||||
|
||||
// Registers `func` under `name`.
// Returns 0 on success and -1 when the name is already taken, in which case
// the existing entry is left untouched.
int FuseKernel::Register(const std::string& name, FuseFunc func) {
  // emplace() reports whether the key was new, so no separate count() lookup
  // is needed; it never overwrites an existing entry.
  const bool inserted = entries_.emplace(name, func).second;
  if (!inserted) {
    return -1;
  }
  MMDEPLOY_DEBUG("Register fuse kernel: '{}'", name);
  return 0;
}
|
||||
|
||||
} // namespace elena
|
||||
} // namespace mmdeploy
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#ifndef MMDEPLOY_ELENA_REGISTRY_H_
|
||||
#define MMDEPLOY_ELENA_REGISTRY_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
#include "mmdeploy/core/macro.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace elena {
|
||||
|
||||
/// Signature shared by all fused preprocessing kernels.
/// Parameter meanings beyond what the names state are not defined in this
/// header; see the kernel generator for the exact contract.
using FuseFunc = void (*)(void* stream,               // native stream handle
                          uint8_t* data_in,           // source image bytes
                          int src_h,                  // source height
                          int src_w,                  // source width
                          const char* format,         // source pixel format name
                          int resize_h,               // resize target height
                          int resize_w,               // resize target width
                          const char* interpolation,  // resize mode name
                          int crop_top,               //
                          int crop_left,              //
                          int crop_h,                 //
                          int crop_w,                 //
                          float mean0,                // per-channel mean
                          float mean1,                //
                          float mean2,                //
                          float std0,                 // per-channel std
                          float std1,                 //
                          float std2,                 //
                          int pad_top,                //
                          int pad_left,               //
                          int pad_bottom,             //
                          int pad_right,              //
                          int pad_h,                  //
                          int pad_w,                  //
                          float pad_value,            //
                          float* data_out,            // destination buffer
                          int dst_h,                  // destination height
                          int dst_w);                 // destination width
|
||||
|
||||
class MMDEPLOY_API FuseKernel {
|
||||
public:
|
||||
static FuseKernel& Get();
|
||||
int Register(const std::string& name, FuseFunc func);
|
||||
FuseFunc GetFunc(const std::string& name);
|
||||
|
||||
private:
|
||||
FuseKernel() = default;
|
||||
std::map<std::string, FuseFunc> entries_;
|
||||
};
|
||||
|
||||
/// Helper whose constructor registers a kernel; meant to be instantiated as a
/// static object (see REGISTER_FUSE_KERNEL). The result of the registration is
/// not inspected, so a duplicate name is dropped silently.
class MMDEPLOY_API FuseKernelRegister {
 public:
  FuseKernelRegister(const std::string& name, FuseFunc func) {
    auto& registry = FuseKernel::Get();
    registry.Register(name, func);
  }
};
|
||||
|
||||
} // namespace elena
|
||||
} // namespace mmdeploy
|
||||
|
||||
// Defines a file-local static FuseKernelRegister named g_register_<name>_<func>
// whose constructor registers `func` in the FuseKernel registry under
// `module_name`. The expansion already ends with a semicolon.
#define REGISTER_FUSE_KERNEL(name, module_name, func) \
|
||||
static ::mmdeploy::elena::FuseKernelRegister g_register_##name##_##func(module_name, func);
|
||||
|
||||
#endif
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/preprocess/transform/image2tensor.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace elena {
|
||||
|
||||
class ImageToTensorImpl : public ::mmdeploy::ImageToTensorImpl {
|
||||
public:
|
||||
explicit ImageToTensorImpl(const Value& args) : ::mmdeploy::ImageToTensorImpl(args) {}
|
||||
|
||||
protected:
|
||||
Result<Tensor> HWC2CHW(const Tensor& tensor) override {
|
||||
auto& src_desc = tensor.desc();
|
||||
auto data_type = src_desc.data_type;
|
||||
auto shape = src_desc.shape;
|
||||
shape = {shape[0], shape[3], shape[1], shape[2]};
|
||||
|
||||
TensorDesc dummy_desc = {Device{"cpu"}, data_type, shape};
|
||||
Tensor dummy(dummy_desc, dummy_buffer_);
|
||||
|
||||
return dummy;
|
||||
}
|
||||
Buffer dummy_buffer_{Device{"cpu"}, 0, nullptr};
|
||||
};
|
||||
|
||||
/// Factory registered under the name "elena" (version 1) for the placeholder
/// image-to-tensor step.
class ImageToTensorImplCreator : public Creator<::mmdeploy::ImageToTensorImpl> {
 public:
  auto GetName() const -> const char* override { return "elena"; }

  auto GetVersion() const -> int override { return 1; }

  auto Create(const Value& args) -> ReturnType override {
    auto impl = std::make_unique<ImageToTensorImpl>(args);
    return impl;
  }
};
|
||||
|
||||
} // namespace elena
|
||||
} // namespace mmdeploy
|
||||
|
||||
using mmdeploy::ImageToTensorImpl;
|
||||
using mmdeploy::elena::ImageToTensorImplCreator;
|
||||
REGISTER_MODULE(ImageToTensorImpl, ImageToTensorImplCreator);
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/preprocess/transform/load.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace elena {
|
||||
|
||||
class PrepareImageImpl : public ::mmdeploy::PrepareImageImpl {
|
||||
public:
|
||||
explicit PrepareImageImpl(const Value& args) : ::mmdeploy::PrepareImageImpl(args){};
|
||||
~PrepareImageImpl() override = default;
|
||||
|
||||
protected:
|
||||
Result<Tensor> ConvertToBGR(const Mat& img) override {
|
||||
auto data_type = img.type();
|
||||
auto format = img.pixel_format();
|
||||
TensorShape shape = {1, img.height(), img.width(), 3};
|
||||
|
||||
if (format == PixelFormat::kNV12 || format == PixelFormat::kNV21) {
|
||||
shape[1] = shape[1] / 3 * 2;
|
||||
}
|
||||
|
||||
if (arg_.to_float32) {
|
||||
data_type = DataType::kFLOAT;
|
||||
}
|
||||
|
||||
TensorDesc dummy_desc = {Device{"cpu"}, data_type, shape};
|
||||
Tensor dummy(dummy_desc, dummy_buffer_);
|
||||
|
||||
return dummy;
|
||||
}
|
||||
|
||||
Result<Tensor> ConvertToGray(const Mat& img) override {
|
||||
auto data_type = img.type();
|
||||
auto format = img.pixel_format();
|
||||
TensorShape shape = {1, img.height(), img.width(), 1};
|
||||
|
||||
if (format == PixelFormat::kNV12 || format == PixelFormat::kNV21) {
|
||||
shape[1] = shape[1] / 3 * 2;
|
||||
}
|
||||
|
||||
if (arg_.to_float32) {
|
||||
data_type = DataType::kFLOAT;
|
||||
}
|
||||
|
||||
TensorDesc dummy_desc = {Device{"cpu"}, data_type, shape};
|
||||
Tensor dummy(dummy_desc, dummy_buffer_);
|
||||
|
||||
return dummy;
|
||||
}
|
||||
Buffer dummy_buffer_{Device{"cpu"}, 0, nullptr};
|
||||
};
|
||||
|
||||
/// Factory registered under the name "elena" (version 1) for the placeholder
/// image-loading step.
class PrepareImageImplCreator : public Creator<::mmdeploy::PrepareImageImpl> {
 public:
  auto GetName() const -> const char* override { return "elena"; }

  auto GetVersion() const -> int override { return 1; }

  auto Create(const Value& args) -> ReturnType override {
    return std::make_unique<PrepareImageImpl>(args);
  }
};
|
||||
|
||||
} // namespace elena
|
||||
} // namespace mmdeploy
|
||||
|
||||
using mmdeploy::PrepareImageImpl;
|
||||
using mmdeploy::elena::PrepareImageImplCreator;
|
||||
REGISTER_MODULE(PrepareImageImpl, PrepareImageImplCreator);
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/preprocess/transform/normalize.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace elena {
|
||||
|
||||
class NormalizeImpl : public ::mmdeploy::NormalizeImpl {
|
||||
public:
|
||||
NormalizeImpl(const Value& value) : ::mmdeploy::NormalizeImpl(value){};
|
||||
~NormalizeImpl() = default;
|
||||
|
||||
protected:
|
||||
Result<Tensor> NormalizeImage(const Tensor& tensor) override {
|
||||
auto& src_desc = tensor.desc();
|
||||
auto data_type = DataType::kFLOAT;
|
||||
auto shape = src_desc.shape;
|
||||
|
||||
TensorDesc dummy_desc = {Device{"cpu"}, data_type, shape};
|
||||
Tensor dummy(dummy_desc, dummy_buffer_);
|
||||
|
||||
return dummy;
|
||||
}
|
||||
Buffer dummy_buffer_{Device{"cpu"}, 0, nullptr};
|
||||
};
|
||||
|
||||
/// Factory registered under the name "elena" (version 1) for the placeholder
/// normalize step.
class NormalizeImplCreator : public Creator<::mmdeploy::NormalizeImpl> {
 public:
  auto GetName() const -> const char* override { return "elena"; }

  auto GetVersion() const -> int override { return 1; }

  auto Create(const Value& args) -> std::unique_ptr<::mmdeploy::NormalizeImpl> override {
    auto impl = std::make_unique<NormalizeImpl>(args);
    return impl;
  }
};
|
||||
|
||||
} // namespace elena
|
||||
} // namespace mmdeploy
|
||||
|
||||
using mmdeploy::NormalizeImpl;
|
||||
using mmdeploy::elena::NormalizeImplCreator;
|
||||
REGISTER_MODULE(NormalizeImpl, NormalizeImplCreator);
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/preprocess/transform/pad.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace elena {
|
||||
|
||||
class PadImpl : public ::mmdeploy::PadImpl {
|
||||
public:
|
||||
PadImpl(const Value& args) : ::mmdeploy::PadImpl(args) {}
|
||||
|
||||
protected:
|
||||
Result<Tensor> PadImage(const Tensor& img, const std::array<int, 4>& padding) override {
|
||||
auto& src_desc = img.desc();
|
||||
auto data_type = src_desc.data_type;
|
||||
auto shape = src_desc.shape; // 1 x h x w x c
|
||||
shape[1] += padding[1] + padding[3];
|
||||
shape[2] += padding[0] + padding[2];
|
||||
|
||||
TensorDesc dummy_desc = {Device{"cpu"}, data_type, shape};
|
||||
Tensor dummy(dummy_desc, dummy_buffer_);
|
||||
|
||||
return dummy;
|
||||
}
|
||||
Buffer dummy_buffer_{Device{"cpu"}, 0, nullptr};
|
||||
};
|
||||
|
||||
/// Factory registered under the name "elena" (version 1) for the placeholder
/// pad step.
class PadImplCreator : public Creator<::mmdeploy::PadImpl> {
 public:
  auto GetName() const -> const char* override { return "elena"; }

  auto GetVersion() const -> int override { return 1; }

  auto Create(const Value& args) -> ReturnType override {
    return std::make_unique<PadImpl>(args);
  }
};
|
||||
|
||||
} // namespace elena
|
||||
} // namespace mmdeploy
|
||||
|
||||
using mmdeploy::PadImpl;
|
||||
using mmdeploy::elena::PadImplCreator;
|
||||
REGISTER_MODULE(PadImpl, PadImplCreator);
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/preprocess/transform/resize.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace elena {
|
||||
|
||||
class ResizeImpl final : public ::mmdeploy::ResizeImpl {
|
||||
public:
|
||||
ResizeImpl(const Value& args) : ::mmdeploy::ResizeImpl(args) {}
|
||||
~ResizeImpl() = default;
|
||||
|
||||
protected:
|
||||
Result<Tensor> ResizeImage(const Tensor& img, int dst_h, int dst_w) override {
|
||||
auto& src_desc = img.desc();
|
||||
auto data_type = src_desc.data_type;
|
||||
TensorShape shape = {1, dst_h, dst_w, img.shape().back()};
|
||||
|
||||
TensorDesc dummy_desc = {Device{"cpu"}, data_type, shape};
|
||||
Tensor dummy(dummy_desc, dummy_buffer_);
|
||||
|
||||
return dummy;
|
||||
}
|
||||
Buffer dummy_buffer_{Device{"cpu"}, 0, nullptr};
|
||||
};
|
||||
|
||||
/// Factory registered under the name "elena" (version 1) for the placeholder
/// resize step.
class ResizeImplCreator : public Creator<mmdeploy::ResizeImpl> {
 public:
  auto GetName() const -> const char* override { return "elena"; }

  auto GetVersion() const -> int override { return 1; }

  auto Create(const Value& args) -> ReturnType override {
    auto impl = std::make_unique<ResizeImpl>(args);
    return impl;
  }
};
|
||||
|
||||
} // namespace elena
|
||||
} // namespace mmdeploy
|
||||
|
||||
using mmdeploy::ResizeImpl;
|
||||
using mmdeploy::elena::ResizeImplCreator;
|
||||
REGISTER_MODULE(ResizeImpl, ResizeImplCreator);
|
|
@ -12,7 +12,8 @@ set(SRCS
|
|||
normalize.cpp
|
||||
pad.cpp
|
||||
resize.cpp
|
||||
transform.cpp)
|
||||
transform.cpp
|
||||
tracer.cpp)
|
||||
mmdeploy_add_module(${PROJECT_NAME} LIBRARY "${SRCS}")
|
||||
target_include_directories(
|
||||
${PROJECT_NAME} PUBLIC $<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/preprocess>)
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
namespace mmdeploy {
|
||||
|
||||
CollectImpl::CollectImpl(const Value &args) {
|
||||
CollectImpl::CollectImpl(const Value &args) : TransformImpl(args) {
|
||||
if (!args.contains("keys") || !args["keys"].is_array()) {
|
||||
throw std::invalid_argument("'keys' is missed in arguments, or it is not an array as expected");
|
||||
}
|
||||
|
@ -57,7 +57,12 @@ Result<Value> CollectImpl::Process(const Value &input) {
|
|||
}
|
||||
|
||||
Collect::Collect(const Value &args, int version) : Transform(args) {
|
||||
impl_ = Registry<CollectImpl>::Get().GetCreator("cpu", version)->Create(args);
|
||||
auto impl_creator = Registry<CollectImpl>::Get().GetCreator(specified_platform_, version);
|
||||
if (nullptr == impl_creator) {
|
||||
MMDEPLOY_ERROR("'Collect' is not supported on '{}' platform", specified_platform_);
|
||||
throw_exception(eEntryNotFound);
|
||||
}
|
||||
impl_ = impl_creator->Create(args);
|
||||
}
|
||||
|
||||
Result<Value> Collect::Process(const Value &input) { return impl_->Process(input); }
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
#include "transform.h"
|
||||
namespace mmdeploy {
|
||||
|
||||
class MMDEPLOY_API CollectImpl : public Module {
|
||||
class MMDEPLOY_API CollectImpl : public TransformImpl {
|
||||
public:
|
||||
explicit CollectImpl(const Value& args);
|
||||
~CollectImpl() override = default;
|
||||
|
|
|
@ -14,6 +14,12 @@ Compose::Compose(const Value& args, int version) : Transform(args) {
|
|||
Value context;
|
||||
context = args["context"];
|
||||
context["stream"].get_to(stream_);
|
||||
bool fuse_transform = args.value("fuse_transform", false);
|
||||
if (fuse_transform) {
|
||||
std::string sha256 = args.value("sha256", std::string(""));
|
||||
context["fuse_transform"] = true;
|
||||
context["sha256"] = sha256;
|
||||
}
|
||||
for (auto cfg : args["transforms"]) {
|
||||
cfg["context"] = context;
|
||||
auto type = cfg.value("type", std::string{});
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include "crop.h"
|
||||
|
||||
#include "mmdeploy/archive/json_archive.h"
|
||||
#include "mmdeploy/preprocess/transform/tracer.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
@ -47,6 +48,13 @@ Result<Value> CenterCropImpl::Process(const Value& input) {
|
|||
|
||||
auto& shape = dst_tensor.desc().shape;
|
||||
|
||||
// trace static info & runtime args
|
||||
if (output.contains("__tracer__")) {
|
||||
output["__tracer__"].get_ref<Tracer&>().CenterCrop(
|
||||
{y1, x1, h - (int)shape[1] - y1, w - (int)shape[2] - x1}, {(int)shape[1], (int)shape[2]},
|
||||
tensor.data_type());
|
||||
}
|
||||
|
||||
output["img_shape"] = {shape[0], shape[1], shape[2], shape[3]};
|
||||
if (input.contains("scale_factor")) {
|
||||
// image has been processed by `Resize` transform before.
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include "mmdeploy/archive/json_archive.h"
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
#include "mmdeploy/preprocess/transform/tracer.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
|
||||
|
@ -40,6 +41,12 @@ Result<Value> DefaultFormatBundleImpl::Process(const Value& input) {
|
|||
output["img_norm_cfg"]["to_rgb"] = false;
|
||||
}
|
||||
|
||||
// trace static info & runtime args
|
||||
if (output.contains("__tracer__")) {
|
||||
output["__tracer__"].get_ref<Tracer&>().DefaultFormatBundle(arg_.img_to_float,
|
||||
in_tensor.data_type());
|
||||
}
|
||||
|
||||
// transpose
|
||||
OUTCOME_TRY(tensor, HWC2CHW(tensor));
|
||||
SetTransformData(output, "img", std::move(tensor));
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include "mmdeploy/archive/json_archive.h"
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
#include "mmdeploy/preprocess/transform/tracer.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
|
||||
|
@ -28,6 +29,10 @@ Result<Value> ImageToTensorImpl::Process(const Value& input) {
|
|||
|
||||
OUTCOME_TRY(auto dst, HWC2CHW(src_tensor));
|
||||
SetTransformData(output, key, std::move(dst));
|
||||
|
||||
if (output.contains("__tracer__")) {
|
||||
output["__tracer__"].get_ref<Tracer&>().ImageToTensor(src_tensor.data_type());
|
||||
}
|
||||
} // for key
|
||||
MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
|
||||
return output;
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include "load.h"
|
||||
|
||||
#include "mmdeploy/archive/json_archive.h"
|
||||
#include "mmdeploy/preprocess/transform/tracer.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
|
||||
|
@ -52,6 +53,13 @@ Result<Value> PrepareImageImpl::Process(const Value& input) {
|
|||
|
||||
SetTransformData(output, "img", std::move(tensor));
|
||||
|
||||
// trace static info & runtime args
|
||||
Tracer tracer;
|
||||
tracer.PrepareImage(arg_.color_type, arg_.to_float32,
|
||||
{1, src_mat.height(), src_mat.width(), src_mat.channel()},
|
||||
src_mat.pixel_format(), src_mat.type());
|
||||
output["__tracer__"] = std::move(tracer);
|
||||
|
||||
MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
|
||||
|
||||
return output;
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#include "mmdeploy/archive/json_archive.h"
|
||||
#include "mmdeploy/core/registry.h"
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
#include "mmdeploy/preprocess/transform/tracer.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
@ -75,6 +76,12 @@ Result<Value> NormalizeImpl::Process(const Value& input) {
|
|||
output["img_norm_cfg"]["std"].push_back(v);
|
||||
}
|
||||
output["img_norm_cfg"]["to_rgb"] = arg_.to_rgb;
|
||||
|
||||
// trace static info & runtime args
|
||||
if (output.contains("__tracer__")) {
|
||||
output["__tracer__"].get_ref<Tracer&>().Normalize(arg_.mean, arg_.std, arg_.to_rgb,
|
||||
desc.data_type);
|
||||
}
|
||||
}
|
||||
MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
|
||||
return output;
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
#include "pad.h"
|
||||
|
||||
#include "mmdeploy/archive/json_archive.h"
|
||||
#include "mmdeploy/preprocess/transform/tracer.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
@ -95,6 +96,13 @@ Result<Value> PadImpl::Process(const Value& input) {
|
|||
output["pad_shape"].push_back(v);
|
||||
}
|
||||
|
||||
// trace static info & runtime args
|
||||
if (output.contains("__tracer__")) {
|
||||
output["__tracer__"].get_ref<Tracer&>().Pad(
|
||||
arg_.pad_val, {padding[1], padding[0], padding[3], padding[2]},
|
||||
{(int)output_tensor.shape(1), (int)output_tensor.shape(2)}, output_tensor.data_type());
|
||||
}
|
||||
|
||||
SetTransformData(output, key, std::move(output_tensor));
|
||||
}
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include "mmdeploy/archive/json_archive.h"
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
#include "mmdeploy/preprocess/transform/tracer.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
|
@ -109,6 +110,12 @@ Result<Value> ResizeImpl::Process(const Value& input) {
|
|||
output["keep_ratio"] = arg_.keep_ratio;
|
||||
|
||||
SetTransformData(output, key, std::move(dst_img));
|
||||
|
||||
// trace static info & runtime args
|
||||
if (output.contains("__tracer__")) {
|
||||
output["__tracer__"].get_ref<Tracer&>().Resize(arg_.interpolation, {dst_h, dst_w},
|
||||
src_img.data_type());
|
||||
}
|
||||
}
|
||||
|
||||
MMDEPLOY_DEBUG("output: {}", to_json(output).dump(2));
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#include "mmdeploy/preprocess/transform/tracer.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
|
||||
using namespace trace;
|
||||
|
||||
void Tracer::PrepareImage(const std::string &color_type, bool to_float32, TensorShape shape,
|
||||
PixelFormat pfmt, DataType dtype) {
|
||||
PixelFormat pdst = PixelFormat::kGRAYSCALE;
|
||||
if (color_type == "color" || color_type == "color_ignore_orientation") {
|
||||
pdst = PixelFormat::kBGR;
|
||||
}
|
||||
trans_.push_back(CvtColorParam{dtype, pfmt, pdst});
|
||||
state_ = {dtype, pdst, shape};
|
||||
|
||||
if (to_float32) {
|
||||
trans_.push_back(CastParam{dtype, DataType::kFLOAT});
|
||||
state_.dtype = DataType::kFLOAT;
|
||||
common_dtype_ = DataType::kFLOAT;
|
||||
}
|
||||
}
|
||||
|
||||
// Records a resize step and updates the tracked shape.
//
// `mode` is the interpolation name, `size` is the two-element target size
// {height, width} and `dtype` is the data type of the image being resized.
void Tracer::Resize(const std::string &mode, const std::vector<int> &size, DataType dtype) {
  trans_.push_back(ResizeParam{dtype, size, mode});
  // `size` holds exactly two entries, so the width lives at index 1; reading
  // index 2 would run past the end of the vector.
  state_.shape[1] = size[0];
  state_.shape[2] = size[1];
}
|
||||
|
||||
void Tracer::Pad(float pad_val, const std::vector<int> &tlbr, const std::vector<int> &size,
|
||||
DataType dtype) {
|
||||
trans_.push_back(PadParam{dtype, pad_val, tlbr, size});
|
||||
state_.shape[1] = size[0];
|
||||
state_.shape[2] = size[2];
|
||||
}
|
||||
|
||||
void Tracer::Normalize(const std::vector<float> &mean, const std::vector<float> &std, bool to_rgb,
|
||||
DataType dtype) {
|
||||
if (common_dtype_ == std::nullopt || common_dtype_.value() != DataType::kFLOAT) {
|
||||
trans_.push_back(CastParam{dtype, DataType::kFLOAT});
|
||||
state_.dtype = DataType::kFLOAT;
|
||||
common_dtype_ = DataType::kFLOAT;
|
||||
}
|
||||
|
||||
if (to_rgb) {
|
||||
trans_.push_back(CvtColorParam{DataType::kFLOAT, state_.pfmt, PixelFormat::kRGB});
|
||||
state_.pfmt = PixelFormat::kRGB;
|
||||
}
|
||||
|
||||
trans_.push_back(NormParam{state_.dtype, mean, std});
|
||||
}
|
||||
|
||||
void Tracer::CenterCrop(const std::vector<int> &tlbr, const std::vector<int> &size,
|
||||
DataType dtype) {
|
||||
trans_.push_back(CropParam{state_.dtype, tlbr, size});
|
||||
state_.shape[1] = size[0];
|
||||
state_.shape[2] = size[2];
|
||||
}
|
||||
|
||||
void Tracer::DefaultFormatBundle(bool to_float, DataType dtype) {
|
||||
if (to_float && (common_dtype_ == std::nullopt || common_dtype_.value() != DataType::kFLOAT)) {
|
||||
trans_.push_back(CastParam{dtype, DataType::kFLOAT});
|
||||
state_.dtype = DataType::kFLOAT;
|
||||
common_dtype_ = DataType::kFLOAT;
|
||||
}
|
||||
|
||||
trans_.push_back(HWC2CHWParam{state_.dtype});
|
||||
state_.shape = {state_.shape[0], state_.shape[3], state_.shape[1], state_.shape[2]};
|
||||
}
|
||||
|
||||
// Records the ImageToTensor step as an HWC -> CHW layout change using the
// tracked data type, and permutes the tracked shape as {0, 3, 1, 2}.
void Tracer::ImageToTensor(DataType dtype) {
  trans_.push_back(HWC2CHWParam{state_.dtype});

  // Work from a copy so the permutation never reads from the shape being overwritten.
  const TensorShape before = state_.shape;
  state_.shape = {before[0], before[3], before[1], before[2]};
}
|
||||
|
||||
} // namespace mmdeploy
|
|
@ -0,0 +1,102 @@
|
|||
// Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
#ifndef MMDEPLOY_SRC_CORE_TRACER_H_
|
||||
#define MMDEPLOY_SRC_CORE_TRACER_H_
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "mmdeploy/core/macro.h"
|
||||
#include "mmdeploy/core/registry.h"
|
||||
#include "mmdeploy/core/tensor.h"
|
||||
#include "mmdeploy/core/types.h"
|
||||
|
||||
namespace mmdeploy {
|
||||
namespace trace {
|
||||
|
||||
struct CvtColorParam {
|
||||
DataType dtype;
|
||||
PixelFormat srt;
|
||||
PixelFormat dst;
|
||||
};
|
||||
|
||||
struct CastParam {
|
||||
DataType srt;
|
||||
DataType dst;
|
||||
};
|
||||
|
||||
struct ResizeParam {
|
||||
DataType dtype;
|
||||
std::vector<int> size;
|
||||
std::string mode;
|
||||
};
|
||||
|
||||
struct CropParam {
|
||||
DataType dtype;
|
||||
std::vector<int> tlbr;
|
||||
std::vector<int> size;
|
||||
};
|
||||
|
||||
struct NormParam {
|
||||
DataType dtype;
|
||||
std::vector<float> mean;
|
||||
std::vector<float> std;
|
||||
};
|
||||
|
||||
struct PadParam {
|
||||
DataType dtype;
|
||||
float pad_val;
|
||||
std::vector<int> tlbr;
|
||||
std::vector<int> size;
|
||||
};
|
||||
|
||||
struct HWC2CHWParam {
|
||||
DataType dtype;
|
||||
};
|
||||
|
||||
using TransParamType = std::variant<CvtColorParam, CastParam, ResizeParam, PadParam, NormParam,
|
||||
CropParam, HWC2CHWParam>;
|
||||
|
||||
} // namespace trace
|
||||
|
||||
class MMDEPLOY_API Tracer {
|
||||
public:
|
||||
void Resize(const std::string &mode, const std::vector<int> &size, DataType dtype);
|
||||
|
||||
void PrepareImage(const std::string &color_type, bool to_float32, TensorShape shape,
|
||||
PixelFormat pfmt, DataType dtype);
|
||||
|
||||
void Pad(float pad_val, const std::vector<int> &tlbr, const std::vector<int> &size,
|
||||
DataType dtype);
|
||||
|
||||
void Normalize(const std::vector<float> &mean, const std::vector<float> &std, bool to_rgb,
|
||||
DataType dtype);
|
||||
|
||||
void CenterCrop(const std::vector<int> &tlbr, const std::vector<int> &size, DataType dtype);
|
||||
|
||||
void DefaultFormatBundle(bool to_float, DataType dtype);
|
||||
|
||||
void ImageToTensor(DataType dtype);
|
||||
|
||||
public:
|
||||
struct state_t {
|
||||
DataType dtype;
|
||||
PixelFormat pfmt;
|
||||
TensorShape shape;
|
||||
};
|
||||
using StateType = struct state_t;
|
||||
StateType state_;
|
||||
std::optional<DataType> common_dtype_;
|
||||
std::vector<trace::TransParamType> trans_;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct is_cast_by_erasure<Tracer> : std::true_type {};
|
||||
|
||||
MMDEPLOY_REGISTER_TYPE_ID(Tracer, 9);
|
||||
|
||||
} // namespace mmdeploy
|
||||
|
||||
#endif // MMDEPLOY_SRC_CORE_TRACER_H_
|
|
@ -36,6 +36,11 @@ Transform::Transform(const Value &args) {
|
|||
Device device{"cpu"};
|
||||
if (args.contains("context")) {
|
||||
device = args["context"].value("device", device);
|
||||
bool fuse_transform = args["context"].value("fuse_transform", false);
|
||||
if (fuse_transform) {
|
||||
specified_platform_ = "elena";
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
Platform platform(device.platform_id());
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# Fuse Transform(Experimental)
|
||||
|
||||
MMDeploy provides the ability to fuse transforms for acceleration in some cases.
|
||||
|
||||
When running inference with the SDK, you can edit `pipeline.json` to turn on the fuse option.
|
||||
|
||||
To bring the ability of fuse transform to MMDeploy, you can refer to the use of CVFusion.
|
||||
|
||||
## 1. Use CVFusion
|
||||
|
||||
There are two ways to use CVFusion, one is to use the pre-generated kernel code, the other is to generate the code yourself.
|
||||
|
||||
A) Use pre-generated kernel code
|
||||
|
||||
i) Download the kernel code from here, unzip it and copy the csrc folder to the mmdeploy root folder.
|
||||
|
||||
[elena_kernel-20220823.tar.gz](https://github.com/open-mmlab/mmdeploy/files/9399795/elena_kernel-20220823.tar.gz)
|
||||
|
||||
ii) Add option `-DMMDEPLOY_ELENA_FUSION=ON` when compile MMDeploy.
|
||||
|
||||
B) Generate kernel code by yourself
|
||||
|
||||
i) Compile CVFusion
|
||||
|
||||
```bash
|
||||
$ git clone --recursive https://github.com/OpenComputeLab/CVFusion.git
|
||||
$ cd CVFusion
|
||||
$ bash build.sh
|
||||
```
|
||||
|
||||
```
|
||||
# add OpFuse to PATH
|
||||
$ export PATH=`pwd`/build/examples/MMDeploy:$PATH
|
||||
```
|
||||
|
||||
ii) Download algorithm codebase
|
||||
|
||||
```bash
|
||||
$ tree -L 1 .
|
||||
├── mmdeploy
|
||||
├── mmclassification
|
||||
├── mmdetection
|
||||
├── mmsegmentation
|
||||
├── ...
|
||||
```
|
||||
|
||||
iii) Generate kernel code
|
||||
|
||||
```bash
|
||||
python tools/elena/extract_transform.py ..
|
||||
# The generated code will be saved to csrc/mmdeploy/preprocess/elena/{cpu_kernel,cuda_kernel}
|
||||
```
|
||||
|
||||
iv) Add option `-DMMDEPLOY_ELENA_FUSION=ON` when compile MMDeploy.
|
||||
|
||||
## 2. Model conversion
|
||||
|
||||
Add the `--dump-info` argument when converting a model; this generates the files that the SDK needs.
|
||||
|
||||
```bash
|
||||
$ export MODEL_CONFIG=/path/to/mmclassification/configs/resnet/resnet18_8xb32_in1k.py
|
||||
$ export MODEL_PATH=https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth
|
||||
|
||||
$ python tools/deploy.py \
|
||||
configs/mmcls/classification_onnxruntime_static.py \
|
||||
$MODEL_CONFIG \
|
||||
$MODEL_PATH \
|
||||
tests/data/tiger.jpeg \
|
||||
--work-dir resnet18 \
|
||||
--device cpu \
|
||||
--dump-info
|
||||
```
|
||||
|
||||
## 3. Model Inference
|
||||
|
||||
If the model's preprocessing supports fusion, there will be a field named `fuse_transform` in `pipeline.json`. It is the fusion switch, and its default value `false` means fusion is off. Change this field to `true` to enable the fuse option.
|
|
@ -326,6 +326,10 @@ For more SDK C++ API usages, please read these [samples](https://github.com/open
|
|||
For the rest C, C# and Java API usages, please read [C demos](https://github.com/open-mmlab/mmdeploy/tree/master/demo/csrc), [C# demos](https://github.com/open-mmlab/mmdeploy/tree/master/demo/csharp) and [Java demos](https://github.com/open-mmlab/mmdeploy/tree/master/demo/java) respectively.
|
||||
We'll talk about them more in our next release.
|
||||
|
||||
#### Accelerate preprocessing(Experimental)
|
||||
|
||||
If you want to fuse preprocessing for acceleration, please refer to this [doc](./02-how-to-run/fuse_transform.md).
|
||||
|
||||
## Evaluate Model
|
||||
|
||||
You can test the performance of the deployed model using `tools/test.py`. For example,
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# 融合预处理(实验性功能)
|
||||
|
||||
MMDeploy提供了一些Transform融合的能力,当使用SDK进行推理时,可以通过修改pipeline.json来开启融合选项,在某些Transform的组合下可以对预处理进行加速。
|
||||
|
||||
若要在MMDeploy的SDK中加入融合能力,可参考CVFusion的使用。
|
||||
|
||||
## 一、使用CVFusion
|
||||
|
||||
有两种选择,一种是在编译mmdeploy的时候,使用我们提供的融合kernel代码,一种是自己使用CVFusion生成融合kernel的代码。
|
||||
|
||||
A)使用提供的kernel代码
|
||||
|
||||
1. 从这里下载代码,并解压,将csrc文件夹拷贝到mmdeploy的根目录。
|
||||
|
||||
[elena_kernel-20220823.tar.gz](https://github.com/open-mmlab/mmdeploy/files/9399795/elena_kernel-20220823.tar.gz)
|
||||
|
||||
2. 编译mmdeploy的时候,增加选项`-DMMDEPLOY_ELENA_FUSION=ON`
|
||||
|
||||
B) 使用CVFusion生成kernel
|
||||
|
||||
1. 编译CVFusion
|
||||
|
||||
```bash
|
||||
$ git clone --recursive https://github.com/OpenComputeLab/CVFusion.git
|
||||
$ cd CVFusion
|
||||
$ bash build.sh
|
||||
|
||||
# add OpFuse to PATH
|
||||
$ export PATH=`pwd`/build/examples/MMDeploy:$PATH
|
||||
```
|
||||
|
||||
2. 下载各个算法codebase
|
||||
|
||||
```bash
|
||||
$ tree -L 1 .
|
||||
├── mmdeploy
|
||||
├── mmclassification
|
||||
├── mmdetection
|
||||
├── mmsegmentation
|
||||
├── ...
|
||||
```
|
||||
|
||||
3. 生成融合kernel
|
||||
|
||||
```bash
|
||||
python tools/elena/extract_transform.py ..
|
||||
# 生成的代码会保存在csrc/mmdeploy/preprocess/elena/{cpu_kernel,cuda_kernel}
|
||||
```
|
||||
|
||||
4. 编译mmdeploy的时候,增加选项`-DMMDEPLOY_ELENA_FUSION=ON`
|
||||
|
||||
## 二、模型转换
|
||||
|
||||
模型转换时通过`--dump-info`生成SDK所需文件。
|
||||
|
||||
```bash
|
||||
$ export MODEL_CONFIG=/path/to/mmclassification/configs/resnet/resnet18_8xb32_in1k.py
|
||||
$ export MODEL_PATH=https://download.openmmlab.com/mmclassification/v0/resnet/resnet18_8xb32_in1k_20210831-fbbb1da6.pth
|
||||
|
||||
$ python tools/deploy.py \
|
||||
configs/mmcls/classification_onnxruntime_static.py \
|
||||
$MODEL_CONFIG \
|
||||
$MODEL_PATH \
|
||||
tests/data/tiger.jpeg \
|
||||
--work-dir resnet18 \
|
||||
--device cpu \
|
||||
--dump-info
|
||||
|
||||
```
|
||||
|
||||
## 三、模型推理
|
||||
|
||||
若当前pipeline的预处理模块支持融合,`pipeline.json`中会有`fuse_transform`字段,表示融合开关,默认为`false`。当启用融合算法时,需要把`false`改为`true`
|
|
@ -327,6 +327,10 @@ target_link_libraries(${name} PRIVATE mmdeploy ${OpenCV_LIBS})
|
|||
对于 C API、C# API、Java API 的使用方法,请分别阅读代码[C demos](https://github.com/open-mmlab/mmdeploy/tree/master/demo/csrc), [C# demos](https://github.com/open-mmlab/mmdeploy/tree/master/demo/csharp) 和 [Java demos](https://github.com/open-mmlab/mmdeploy/tree/master/demo/java)。
|
||||
我们将在后续版本中详细讲述它们的用法。
|
||||
|
||||
#### 加速预处理(实验性功能)
|
||||
|
||||
若要对预处理进行加速,请查阅[此处](./02-how-to-run/fuse_transform.md)
|
||||
|
||||
## 模型精度评估
|
||||
|
||||
为了测试部署模型的精度,推理效率,我们提供了 `tools/test.py` 来帮助完成相关工作。以上文中的部署模型为例:
|
||||
|
|
|
@ -10,6 +10,7 @@ from mmdeploy.utils import (Backend, Task, get_backend, get_codebase,
|
|||
get_common_config, get_ir_config, get_root_logger,
|
||||
get_task_type, is_dynamic_batch, load_config)
|
||||
from mmdeploy.utils.constants import SDK_TASK_MAP as task_map
|
||||
from .tracer import add_transform_tag, get_transform_static
|
||||
|
||||
|
||||
def get_mmdpeloy_version() -> str:
|
||||
|
@ -400,6 +401,9 @@ def export2SDK(deploy_cfg: Union[str, mmcv.Config],
|
|||
deploy_info = get_deploy(deploy_cfg, model_cfg, work_dir, device)
|
||||
pipeline_info = get_pipeline(deploy_cfg, model_cfg, work_dir, device)
|
||||
detail_info = get_detail(deploy_cfg, model_cfg, pth=pth)
|
||||
transform_static, tag = get_transform_static(
|
||||
pipeline_info['pipeline']['tasks'][0]['transforms'])
|
||||
pipeline_info = add_transform_tag(pipeline_info, tag)
|
||||
mmcv.dump(
|
||||
deploy_info,
|
||||
'{}/deploy.json'.format(work_dir),
|
||||
|
|
|
@ -0,0 +1,153 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
|
||||
import json
|
||||
from hashlib import sha256
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
|
||||
class TraceFunc:
    """Registry that maps a transform name to the function tracing it."""

    def __init__(self):
        self.module_dict = {}

    def register_module(self, name):
        """Return a decorator that stores the decorated function under `name`.

        Raises:
            KeyError: if `name` has already been registered.
        """
        if name in self.module_dict:
            raise KeyError(f'{name} is already registered')

        def _decorator(fn):
            self.module_dict[name] = fn
            return fn

        return _decorator

    def get(self, name):
        """Return the function registered under `name` (KeyError if absent)."""
        return self.module_dict[name]
|
||||
|
||||
|
||||
_TRANSFORM_WRAPPER = TraceFunc()
|
||||
|
||||
|
||||
class Context:
    """State carried through one trace of a transform pipeline."""

    def __init__(self):
        # Steps recorded so far, each a dict with a 'type' key.
        self.transforms = []
        # Becomes 'float32' once a cast to float has been recorded.
        self.dtype = None
|
||||
|
||||
|
||||
@_TRANSFORM_WRAPPER.register_module(name='LoadImageFromFile')
def load(context: Context, args: Dict):
    """Trace LoadImageFromFile.

    Records a colour conversion (BGR for 'color' and
    'color_ignore_orientation', gray otherwise) and, when `to_float32` is
    True, a cast to float. Always returns True.
    """
    color_type = args.get('color_type', 'color')
    if color_type == 'color' or color_type == 'color_ignore_orientation':
        cvt_step = 'cvtColorBGR'
    else:
        cvt_step = 'cvtColorGray'
    context.transforms.append({'type': cvt_step})

    if args.get('to_float32', False) is True:
        context.transforms.append({'type': 'CastFloat'})
        context.dtype = 'float32'
    return True
|
||||
|
||||
|
||||
@_TRANSFORM_WRAPPER.register_module(name='DefaultFormatBundle')
def default_format_bundle(context: Context, args: Dict):
    """Trace DefaultFormatBundle.

    Records a cast to float when `img_to_float` is truthy (default True) and
    the context is not yet float32, then an HWC2CHW step. Always returns True.
    """
    needs_cast = context.dtype is None or context.dtype != 'float32'
    if args.get('img_to_float', True) and needs_cast:
        context.transforms.append({'type': 'CastFloat'})
        context.dtype = 'float32'
    context.transforms.append({'type': 'HWC2CHW'})
    return True
|
||||
|
||||
|
||||
@_TRANSFORM_WRAPPER.register_module(name='Resize')
def resize(context: Context, args: Dict):
    """Trace Resize: record a single 'Resize' step and return True."""
    step = {'type': 'Resize'}
    context.transforms.append(step)
    return True
|
||||
|
||||
|
||||
@_TRANSFORM_WRAPPER.register_module(name='CenterCrop')
def center_crop(context: Context, args: Dict):
    """Trace CenterCrop: record a single 'CenterCrop' step and return True."""
    step = {'type': 'CenterCrop'}
    context.transforms.append(step)
    return True
|
||||
|
||||
|
||||
@_TRANSFORM_WRAPPER.register_module(name='Normalize')
def normalize(context: Context, args: Dict):
    """Trace Normalize.

    Records a cast to float if the context is not yet float32, a conversion to
    RGB when `to_rgb` is True (the default), and the 'Normalize' step itself.
    Always returns True.
    """
    steps = context.transforms
    if context.dtype is None or context.dtype != 'float32':
        steps.append({'type': 'CastFloat'})
        context.dtype = 'float32'
    if args.get('to_rgb', True) is True:
        steps.append({'type': 'cvtColorRGB'})
    steps.append({'type': 'Normalize'})
    return True
|
||||
|
||||
|
||||
@_TRANSFORM_WRAPPER.register_module(name='ImageToTensor')
def image_to_tensor(context: Context, args: Dict):
    """Trace ImageToTensor: record a single 'HWC2CHW' step and return True."""
    step = {'type': 'HWC2CHW'}
    context.transforms.append(step)
    return True
|
||||
|
||||
|
||||
@_TRANSFORM_WRAPPER.register_module(name='Pad')
def pad(context: Context, args: Dict):
    """Trace Pad.

    Returns False without recording anything unless the context is already
    float32; otherwise records a 'Pad' step and returns True.
    """
    if context.dtype == 'float32':
        context.transforms.append({'type': 'Pad'})
        return True
    return False
|
||||
|
||||
|
||||
def add_transform_tag(pipeline_info: Dict, tag: str) -> Dict:
    """Attach the fuse-transform tag to the first task of `pipeline_info`.

    When `tag` is None the pipeline is returned untouched. Otherwise the first
    task gets `sha256` set to `tag` and `fuse_transform` set to False. The
    dict is modified in place and also returned.
    """
    if tag is not None:
        first_task = pipeline_info['pipeline']['tasks'][0]
        first_task['sha256'] = tag
        first_task['fuse_transform'] = False
    return pipeline_info
|
||||
|
||||
|
||||
def get_transform_static(transforms: List) -> Tuple:
    """Get the static transform information for Elena use.

    Args:
        transforms (List): transforms in model_cfg

    Return:
        tuple(): Composed of the static transform information and the tag.
            Both are None when the pipeline cannot be fused: it contains an
            unsupported or repeated transform, a tracer rejects it, or it
            does not end up in float32.
    """
    # Only the basic transforms are supported for now.
    supported_type = [
        'LoadImageFromFile', 'DefaultFormatBundle', 'Resize', 'CenterCrop',
        'Normalize', 'ImageToTensor', 'Collect', 'Pad'
    ]

    # Every transform type may appear at most once.
    seen = set()
    for step in transforms:
        step_type = step['type']
        if step_type not in supported_type or step_type in seen:
            return None, None
        seen.add(step_type)

    context = Context()
    for step in transforms:
        step_type = step['type']
        if step_type == 'Collect':
            # Collect contributes nothing to the fused kernel.
            continue
        tracer = _TRANSFORM_WRAPPER.get(step_type)
        if tracer(context, step) is False:
            return None, None

    if context.dtype != 'float32':
        return None, None

    # The tag is a hash of the recorded step list.
    serialized = json.dumps(context.transforms).encode('utf-8')
    return context.transforms, sha256(serialized).hexdigest()
|
|
@ -8,26 +8,40 @@ using namespace mmdeploy;
|
|||
using namespace std;
|
||||
|
||||
TEST_CASE("test collect constructor", "[collect]") {
|
||||
Device device{"cpu"};
|
||||
Stream stream{device};
|
||||
Value cfg = {{"context", {{"device", device}, {"stream", stream}}}};
|
||||
|
||||
std::string transform_type{"Collect"};
|
||||
auto creator = Registry<Transform>::Get().GetCreator(transform_type, 1);
|
||||
REQUIRE(creator != nullptr);
|
||||
|
||||
REQUIRE_THROWS(creator->Create({}));
|
||||
REQUIRE_THROWS(creator->Create(cfg));
|
||||
|
||||
SECTION("args with 'keys' which is not an array") {
|
||||
REQUIRE_THROWS(creator->Create({{"keys", "img"}}));
|
||||
auto _cfg = cfg;
|
||||
_cfg["keys"] = "img";
|
||||
REQUIRE_THROWS(creator->Create(_cfg));
|
||||
}
|
||||
|
||||
SECTION("args with keys in array") {
|
||||
auto module = creator->Create({{"keys", {"img"}}});
|
||||
auto _cfg = cfg;
|
||||
_cfg["keys"] = {"img"};
|
||||
auto module = creator->Create(_cfg);
|
||||
REQUIRE(module != nullptr);
|
||||
}
|
||||
|
||||
SECTION("args with meta_keys that is not an array") {
|
||||
REQUIRE_THROWS(creator->Create({{"keys", {"img"}}, {"meta_keys", "ori_img"}}));
|
||||
auto _cfg = cfg;
|
||||
_cfg["keys"] = {"img"};
|
||||
_cfg["meta_keys"] = "ori_img";
|
||||
REQUIRE_THROWS(creator->Create(_cfg));
|
||||
}
|
||||
SECTION("args with meta_keys in array") {
|
||||
auto module = creator->Create({{"keys", {"img"}}, {"meta_keys", {"ori_img"}}});
|
||||
auto _cfg = cfg;
|
||||
_cfg["keys"] = {"img"};
|
||||
_cfg["meta_keys"] = {"ori_img"};
|
||||
auto module = creator->Create(_cfg);
|
||||
REQUIRE(module != nullptr);
|
||||
}
|
||||
}
|
||||
|
@ -38,6 +52,10 @@ TEST_CASE("test collect", "[collect]") {
|
|||
vector<std::string> meta_keys{"filename", "ori_filename", "ori_shape", "img_shape",
|
||||
"flip", "flip_direction", "img_norm_cfg"};
|
||||
Value args;
|
||||
Device device{"cpu"};
|
||||
Stream stream{device};
|
||||
args["context"]["device"] = device;
|
||||
args["context"]["stream"] = stream;
|
||||
for (auto& key : keys) {
|
||||
args["keys"].push_back(key);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,194 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# flake8: noqa
|
||||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
import pathlib
|
||||
import shutil
|
||||
import subprocess
|
||||
from glob import glob
|
||||
|
||||
import mmcv
|
||||
import yaml
|
||||
|
||||
from mmdeploy.backend.sdk.export_info import (get_preprocess,
|
||||
get_transform_static)
|
||||
from mmdeploy.utils import get_root_logger, load_config
|
||||
|
||||
print(pathlib.Path(__file__).resolve())
|
||||
MMDEPLOY_PATH = pathlib.Path(__file__).parent.parent.parent.resolve()
|
||||
ELENA_BIN = 'OpFuse'
|
||||
logger = get_root_logger()
|
||||
|
||||
CODEBASE = [
|
||||
'mmclassification', 'mmdetection', 'mmpose', 'mmrotate', 'mmocr',
|
||||
'mmsegmentation', 'mmediting'
|
||||
]
|
||||
|
||||
DEPLOY_CFG = {
|
||||
'Image Classification': 'configs/mmcls/classification_tensorrt_dynamic-224x224-224x224.py',
|
||||
'Object Detection': 'configs/mmdet/detection/detection_tensorrt_static-800x1344.py',
|
||||
'Instance Segmentation': 'configs/mmdet/instance-seg/instance-seg_tensorrt_static-800x1344.py',
|
||||
'Semantic Segmentation': 'configs/mmseg/segmentation_tensorrt_static-512x512.py',
|
||||
'Oriented Object Detection': 'configs/mmrotate/rotated-detection_tensorrt-fp16_dynamic-320x320-1024x1024.py',
|
||||
'Text Recognition': 'configs/mmocr/text-recognition/text-recognition_tensorrt_static-32x32.py',
|
||||
'Text Detection': 'configs/mmocr/text-detection/text-detection_tensorrt_static-512x512.py',
|
||||
'Restorers': 'configs/mmedit/super-resolution/super-resolution_tensorrt_static-256x256.py'
|
||||
} # yapf: disable
|
||||
|
||||
INFO = {
|
||||
'cpu':
|
||||
'''
|
||||
using std::string;
|
||||
|
||||
void FuseFunc(void* stream, uint8_t* data_in, int src_h, int src_w, const char* format,
|
||||
int resize_h, int resize_w, const char* interpolation, int crop_top, int crop_left,
|
||||
int crop_h, int crop_w, float mean0, float mean1, float mean2, float std0, float std1,
|
||||
float std2, int pad_top, int pad_left, int pad_bottom, int pad_right, int pad_h,
|
||||
int pad_w, float pad_value, float* data_out, int dst_h, int dst_w) {
|
||||
const char* interpolation_ = "nearest";
|
||||
if (strcmp(interpolation, "bilinear") == 0) {
|
||||
interpolation_ = "bilinear";
|
||||
}
|
||||
FuseKernel(resize_h, resize_w, crop_h, crop_w, crop_top, crop_left, mean0, mean1, mean2, std0, std1, std2,
|
||||
pad_h, pad_w, pad_top, pad_left, pad_bottom, pad_right, pad_value, data_in, data_out,
|
||||
src_h, src_w, format, interpolation_);
|
||||
}
|
||||
|
||||
REGISTER_FUSE_KERNEL(#TAG#_cpu, "#TAG#_cpu",
|
||||
FuseFunc);
|
||||
''',
|
||||
'cuda':
|
||||
'''
|
||||
void FuseFunc(void* stream, uint8_t* data_in, int src_h, int src_w, const char* format,
|
||||
int resize_h, int resize_w, const char* interpolation, int crop_top, int crop_left,
|
||||
int crop_h, int crop_w, float mean0, float mean1, float mean2, float std0, float std1,
|
||||
float std2, int pad_top, int pad_left, int pad_bottom, int pad_right, int pad_h,
|
||||
int pad_w, float pad_value, float* data_out, int dst_h, int dst_w) {
|
||||
cudaStream_t stream_ = (cudaStream_t)stream;
|
||||
const char* interpolation_ = "nearest";
|
||||
if (strcmp(interpolation, "bilinear") == 0) {
|
||||
interpolation_ = "bilinear";
|
||||
}
|
||||
|
||||
FuseKernelCU(stream_, resize_h, resize_w, crop_h, crop_w, crop_top, crop_left, mean0, mean1, mean2, std0,
|
||||
std1, std2, pad_h, pad_w, pad_top, pad_left, pad_bottom, pad_right, pad_value, data_in,
|
||||
data_out, dst_h, dst_w, src_h, src_w, format, interpolation_);
|
||||
}
|
||||
|
||||
REGISTER_FUSE_KERNEL(#TAG#_cuda, "#TAG#_cuda",
|
||||
FuseFunc);
|
||||
'''
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command line; the only argument is the positional `root_path`."""
    arg_parser = argparse.ArgumentParser(description='Extract transform.')
    arg_parser.add_argument(
        'root_path', help='parent path to codebase(mmdetection for example)')
    return arg_parser.parse_args()
|
||||
|
||||
|
||||
def append_info(device, tag):
    """Patch the generated kernel source in the current directory in place.

    Reads `source.c` (cpu) or `source.cu` (any other device), then:
    * inserts the registry include and a `namespace {device}_{tag} {` opener
      before the first line mentioning `_Kernel` or `__device__`;
    * strips every `extern "C"`;
    * appends the device's registration snippet from INFO, with `#TAG#`
      replaced by `tag`, and a closing brace.
    """
    snippet = INFO[device].replace('#TAG#', tag)
    src_file = 'source.c' if device == 'cpu' else 'source.cu'
    namespace_open = f'namespace {device}_{tag} {{\n'

    with open(src_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    anchor = next((idx for idx, text in enumerate(lines)
                   if '_Kernel' in text or '__device__' in text), None)
    if anchor is not None:
        # The include goes first, then the namespace opener.
        lines[anchor:anchor] = ['#include "elena_registry.h"\n', namespace_open]

    lines = [text.replace('extern "C"', '') for text in lines]
    lines.extend([snippet, '}'])

    with open(src_file, 'w', encoding='utf-8') as f:
        f.writelines(lines)
|
||||
|
||||
|
||||
def generate_source_code(preprocess, transform_static, tag, args):
    """Generate the fused cpu and cuda kernels for one transform pipeline.

    Dumps `preprocess` and `transform_static` as JSON under the elena `json`
    directory, runs the OpFuse binary once per device on the static JSON, and
    for each successful run patches the generated source with `append_info`
    and copies it (plus `elena_int.h`) into the matching kernel directory.

    Nothing is done when the preprocess JSON for `tag` already exists.

    Args:
        preprocess: preprocess info to dump next to the static info.
        transform_static: static transform list passed to OpFuse.
        tag: hash identifying the pipeline; used for every file name.
        args: parsed command-line arguments (unused here).
    """
    kernel_base_dir = osp.join(MMDEPLOY_PATH, 'csrc', 'mmdeploy', 'preprocess',
                               'elena')
    json_work_dir = osp.join(kernel_base_dir, 'json')
    preprocess_json_path = osp.join(json_work_dir, f'{tag}_preprocess.json')
    static_json_path = osp.join(json_work_dir, f'{tag}_static.json')
    if osp.exists(preprocess_json_path):
        return

    # (device, file produced by OpFuse in the cwd, extension of the copy)
    targets = (('cpu', 'source.c', '.cpp'), ('cuda', 'source.cu', '.cu'))

    # The output directories are git-ignored, so a fresh checkout may lack them.
    os.makedirs(json_work_dir, exist_ok=True)
    for device, _, _ in targets:
        os.makedirs(osp.join(kernel_base_dir, f'{device}_kernel'),
                    exist_ok=True)

    mmcv.dump(preprocess, preprocess_json_path, sort_keys=False, indent=4)
    mmcv.dump(transform_static, static_json_path, sort_keys=False, indent=4)

    for device, generated, ext in targets:
        work_dir = osp.join(kernel_base_dir, f'{device}_kernel')
        # Pass an argument list instead of a shell string so that paths
        # containing spaces or shell metacharacters are handled correctly.
        res = subprocess.run([str(ELENA_BIN), static_json_path, device])
        if res.returncode != 0:
            continue
        append_info(device, tag)
        shutil.copyfile(generated, osp.join(work_dir, f'{tag}{ext}'))
        shutil.copyfile('elena_int.h', osp.join(work_dir, 'elena_int.h'))
        os.remove(generated)

    # The header only exists if at least one generation succeeded.
    if osp.exists('elena_int.h'):
        os.remove('elena_int.h')
|
||||
|
||||
|
||||
def extract_one_model(deploy_cfg_, model_cfg_, args):
    """Generate fused kernels for one (deploy config, model config) pair.

    Loads both configs, derives the preprocess info for 'cuda', and generates
    source code only when the transform pipeline yields a tag.
    """
    deploy_cfg, model_cfg = load_config(deploy_cfg_, model_cfg_)
    preprocess = get_preprocess(deploy_cfg, model_cfg, 'cuda')
    # Keep the path of the originating model config with the dumped info.
    preprocess['model_cfg'] = model_cfg_

    transform_static, tag = get_transform_static(preprocess['transforms'])
    if tag is None:
        return
    generate_source_code(preprocess, transform_static, tag, args)
|
||||
|
||||
|
||||
def extract_one_metafile(metafile, codebase, args):
    """Generate fused kernels for every model listed in one metafile.

    Models whose first result's task has no entry in DEPLOY_CFG are skipped.
    Processing is best effort: a failure on one model is logged as a warning
    and the remaining models are still processed.

    Args:
        metafile: path to a YAML metafile with a top-level `Models` list.
        codebase: codebase directory name under `args.root_path`.
        args: parsed command-line arguments; `root_path` is used.
    """
    with open(metafile, encoding='utf-8') as f:
        yaml_info = yaml.load(f, Loader=yaml.FullLoader)
    known_task = list(DEPLOY_CFG.keys())
    for model in yaml_info['Models']:
        try:
            cfg = model['Config']
            task_name = model['Results'][0]['Task']
            if task_name not in known_task:
                continue
            deploy_cfg = osp.join(MMDEPLOY_PATH, DEPLOY_CFG[task_name])
            model_cfg = osp.join(args.root_path, codebase, cfg)
            extract_one_model(deploy_cfg, model_cfg, args)
        except Exception as exc:
            # Keep going, but leave a trace instead of hiding the failure.
            logger.warning(f'skip a model in {metafile}: {exc!r}')
|
||||
|
||||
|
||||
def main():
    """Walk every known codebase under `root_path` and extract its transforms.

    Uses the OpFuse binary built under third_party/CVFusion when it exists;
    otherwise ELENA_BIN keeps its module-level value.
    """
    global ELENA_BIN
    args = parse_args()

    bundled_opfuse = osp.abspath(
        os.path.join(MMDEPLOY_PATH, 'third_party', 'CVFusion', 'build',
                     'examples', 'MMDeploy', 'OpFuse'))
    if osp.exists(bundled_opfuse):
        ELENA_BIN = bundled_opfuse

    for cb in CODEBASE:
        codebase_dir = osp.join(args.root_path, cb)
        if not os.path.exists(codebase_dir):
            logger.warning(f'skip codebase {cb} because it isn\'t exists.')
            continue
        pattern = osp.join(codebase_dir, 'configs', '**/*.yml')
        for metafile in glob(pattern, recursive=True):
            extract_one_metafile(metafile, cb, args)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue