mmdeploy/csrc/model/zip_model_impl.cpp
lvhan028 36124f6205
Merge sdk (#251)
* check in cmake

* move backend_ops to csrc/backend_ops

* check in preprocess, model, some codebase and their c-apis

* check in CMakeLists.txt

* check in parts of test_csrc

* commit everything else

* add readme

* update core's BUILD_INTERFACE directory

* skip codespell on third_party

* update trt_net and ort_net's CMakeLists

* ignore clion's build directory

* check in pybind11

* add onnx.proto. Remove MMDeploy's dependency on ncnn's source code

* export MMDeployTargets only when MMDEPLOY_BUILD_SDK is ON

* remove useless message

* target include directory is wrong

* change target name from mmdeploy_ppl_net to mmdeploy_pplnn_net

* skip install directory

* update project's cmake

* remove useless code

* set CMAKE_BUILD_TYPE to Release by force if it isn't set by user

* update custom ops CMakeLists

* pass object target's source lists

* fix lint end-of-file

* fix lint: trailing whitespace

* fix codespell hook

* remove bicubic_interpolate to csrc/backend_ops/

* set MMDEPLOY_BUILD_SDK OFF

* change custom ops build command

* add spdlog installation command

* update docs on how to checkout pybind11

* move bicubic_interpolate to backend_ops/tensorrt directory

* remove useless code

* correct cmake

* fix typo

* fix typo

* fix install directory

* correct sdk's readme

* set cub dir when cuda version < 11.0

* change directory where clang-format will apply to

* fix build command

* add .clang-format

* change clang-format style from google to file

* reformat csrc/backend_ops

* format sdk's code

* turn off clang-format for some files

* add -Xcompiler=-fno-gnu-unique

* fix trt topk initialize

* check in config for sdk demo

* update cmake script and csrc's readme

* correct config's path

* add cuda include directory, otherwise compile failed in case of tensorrt8.2

* clang-format onnx2ncnn.cpp

Co-authored-by: zhangli <lzhang329@gmail.com>
Co-authored-by: grimoire <yaoqian@sensetime.com>
2021-12-07 10:57:55 +08:00

142 lines
4.0 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright (c) OpenMMLab. All rights reserved.
#include <string.h>
#include <fstream>
#include <map>
#include "archive/json_archive.h"
#include "core/logger.h"
#include "core/model.h"
#include "core/model_impl.h"
#include "zip.h"
using nlohmann::json;
namespace mmdeploy {
// @brief ModelImpl that reads an sdk model packed as a zip archive, either from
// a file on disk or from a memory buffer. After a successful Init, every
// non-directory entry of the archive is indexed by its name so ReadFile can
// look it up.
class ZipModelImpl : public ModelImpl {
 public:
  ZipModelImpl() = default;

  // The object holds a raw libzip handle that is closed in the destructor;
  // a copy would close the same handle twice, so copying is disabled.
  ZipModelImpl(const ZipModelImpl&) = delete;
  ZipModelImpl& operator=(const ZipModelImpl&) = delete;

  ~ZipModelImpl() override {
    // When the archive was opened from a buffer, libzip documents that the
    // archive owns the underlying source and releases it when the archive is
    // closed, so there is nothing else to release here.
    if (zip_ != nullptr) {
      zip_close(zip_);
    }
  }

  // @brief load an sdk model, which HAS TO BE a zip file.
  // The entries of the archive are indexed by name on success.
  // @param model_path path of sdk model file, in zip format
  // @return success, or eInvalidArgument when the file cannot be opened as a
  // zip archive
  Result<void> Init(const std::string& model_path) override {
    int ret = 0;
    zip_ = zip_open(model_path.c_str(), 0, &ret);
    // zip_open reports failure through a null handle; `ret` only carries the
    // reason, so the handle is what gets checked.
    if (zip_ == nullptr) {
      ERROR("open zip file {} failed, ret {}", model_path.c_str(), ret);
      return Status(eInvalidArgument);
    }
    INFO("open sdk model file {} successfully", model_path.c_str());
    return InitZip();
  }

  // @brief load an sdk model from a memory buffer holding a zip archive.
  // Only available when LIBZIP_VERSION_MAJOR >= 1.
  // NOTE(review): the source is created with freep == 0, so libzip does not
  // take ownership of `buffer`; presumably the caller must keep it alive for
  // as long as this object is used - confirm against callers.
  // @param buffer pointer to the zip data
  // @param size size of the zip data in bytes
  // @return success, eFail when the buffer cannot be opened as a zip archive,
  // or eNotSupported when built against libzip older than 1.0
  Result<void> Init(const void* buffer, size_t size) override {
#if LIBZIP_VERSION_MAJOR >= 1
    zip_error_t error;
    zip_error_init(&error);

    struct zip_source* source = zip_source_buffer_create(buffer, size, 0, &error);
    if (source == nullptr) {
      ERROR("create zip source from buffer failed: {}", zip_error_strerror(&error));
      zip_error_fini(&error);
      return Status(eFail);
    }

    zip_ = zip_open_from_source(source, ZIP_RDONLY, &error);
    if (zip_ == nullptr) {
      ERROR("open zip from buffer failed: {}", zip_error_strerror(&error));
      // On failure the source still belongs to the caller and must be freed.
      zip_source_free(source);
      zip_error_fini(&error);
      return Status(eFail);
    }
    // On success the archive owns `source`; it must not be closed or freed
    // separately.
    zip_error_fini(&error);
    return InitZip();
#else
    return Status(eNotSupported);
#endif
  }

  // @brief read the whole content of one file stored in the zip archive.
  // @param file_path name of the entry, exactly as stored in the archive
  // @return the uncompressed content, or eFail when the entry is unknown or
  // cannot be read completely
  Result<std::string> ReadFile(const std::string& file_path) const override {
    const auto iter = file_index_.find(file_path);
    if (iter == file_index_.end()) {
      ERROR("cannot find file {} under dir {}", file_path.c_str(), root_dir_.c_str());
      return Status(eFail);
    }
    const int index = iter->second;

    // Query the size before opening the entry so that a failure here cannot
    // leave an open file handle behind.
    struct zip_stat stat;
    zip_stat_init(&stat);
    const int ret = zip_stat_index(zip_, index, 0, &stat);
    if (ret < 0) {
      ERROR("get stat of file {} error, ret {}", file_path.c_str(), ret);
      return Status(eFail);
    }
    DEBUG("file size {}", static_cast<unsigned long long>(stat.size));

    struct zip_file* file = zip_fopen_index(zip_, index, 0);
    if (nullptr == file) {
      ERROR("read file {} in zip file failed, whose index is {}", file_path.c_str(), index);
      return Status(eFail);
    }

    // zip_fread may deliver fewer bytes than requested, so keep reading until
    // the whole entry has been received, the stream ends, or an error occurs.
    std::string content(static_cast<size_t>(stat.size), '\0');
    size_t total = 0;
    long long count = 0;
    while (total < content.size()) {
      count = zip_fread(file, &content[total], content.size() - total);
      if (count <= 0) {
        break;
      }
      total += static_cast<size_t>(count);
    }
    // Single exit point for the handle: it is closed on success and failure.
    zip_fclose(file);

    if (count < 0 || total != content.size()) {
      ERROR("read data of file {} error, ret {}", file_path.c_str(), count);
      return Status(eFail);
    }
    return content;
  }

  // @brief read and parse the meta file (i.e. deploy.json) from the archive.
  // @return the parsed meta info, the error from ReadFile when deploy.json
  // cannot be read, or eFail when parsing throws
  Result<deploy_meta_info_t> ReadMeta() const override {
    OUTCOME_TRY(auto deploy_json, ReadFile("deploy.json"));
    try {
      deploy_meta_info_t meta;
      from_json(json::parse(deploy_json), meta);
      return meta;
    } catch (const std::exception& e) {
      ERROR("exception happened: {}", e.what());
      return Status(eFail);
    }
  }

 private:
  // @brief build the name -> index map of all files in the opened archive.
  // Entries whose name ends with '/' are directories and are not indexed.
  // @return success, or eFail when an entry cannot be inspected
  Result<void> InitZip() {
    const int files = zip_get_num_files(zip_);
    INFO("there are {} files in sdk model file", files);
    for (int i = 0; i < files; ++i) {
      struct zip_stat stat;
      zip_stat_init(&stat);
      // zip_stat_init leaves `name` null, so a failed query must not fall
      // through to strlen below.
      if (zip_stat_index(zip_, i, 0, &stat) != 0 || stat.name == nullptr) {
        ERROR("get stat of {}-th file in sdk model file failed", i);
        return Status(eFail);
      }
      const size_t length = strlen(stat.name);
      if (length == 0) {
        // Nothing to index, and `length - 1` would wrap around.
        DEBUG("{}-th file has an empty name, skipped", i);
        continue;
      }
      if (stat.name[length - 1] == '/') {
        DEBUG("{}-th file name is: {} which is a directory", i, stat.name);
      } else {
        DEBUG("{}-th file name is: {} which is a file", i, stat.name);
        file_index_[stat.name] = i;
      }
    }
    return success();
  }

  // handle of the opened zip archive, null until Init succeeds
  struct zip* zip_{};
  // root directory in zip file
  std::string root_dir_;
  // a map between file path and its index in zip file
  std::map<std::string, int> file_index_;
};
// Helper whose constructor registers a ZipModelImpl factory in the
// ModelRegistry under the name "ZipModel".
class ZipModelImplRegister {
 public:
  ZipModelImplRegister() {
    // Factory handed to the registry: creates a fresh, uninitialized ZipModelImpl.
    auto creator = []() -> std::unique_ptr<ModelImpl> { return std::make_unique<ZipModelImpl>(); };
    // The value returned by Register() is deliberately discarded.
    static_cast<void>(ModelRegistry::Get().Register("ZipModel", creator));
  }
};

// File-local instance; constructing it performs the registration when this
// translation unit's static objects are initialized.
static ZipModelImplRegister folder_model_register;
} // namespace mmdeploy