// Copyright (c) OpenMMLab. All rights reserved. #include #include #include "archive/json_archive.h" #include "core/logger.h" #include "core/model.h" #include "core/model_impl.h" #include "zip.h" #if __GNUC__ >= 8 #include namespace fs = std::filesystem; #else #include namespace fs = std::experimental::filesystem; #endif using nlohmann::json; namespace mmdeploy { class ZipModelImpl : public ModelImpl { public: ~ZipModelImpl() override { if (zip_ != nullptr) { zip_close(zip_); } #if LIBZIP_VERSION_MAJOR >= 1 if (source_) { zip_source_close(source_); } #endif } // @brief load an sdk model, which HAS TO BE a zip file. // Meta file (i.e. deploy.json) will be extracted and parsed from the zip file // @param sdk_model_path path of sdk model file, in zip format Result Init(const std::string& model_path) override { int ret = 0; zip_ = zip_open(model_path.c_str(), 0, &ret); if (ret != 0) { INFO("open zip file {} failed, ret {}", model_path.c_str(), ret); return Status(eInvalidArgument); } INFO("open sdk model file {} successfully", model_path.c_str()); return InitZip(); } Result Init(const void* buffer, size_t size) override { #if LIBZIP_VERSION_MAJOR >= 1 zip_error_t error{}; source_ = zip_source_buffer_create(buffer, size, 0, &error); if (zip_error_code_zip(&error) != ZIP_ER_OK) { return Status(eFail); } zip_ = zip_open_from_source(source_, ZIP_RDONLY, &error); if (zip_error_code_zip(&error) != ZIP_ER_OK) { return Status(eFail); } return InitZip(); #else return Status(eNotSupported); #endif } Result ReadFile(const std::string& file_path) const override { int ret = 0; int index = -1; auto iter = file_index_.find(file_path); if (iter == file_index_.end()) { ERROR("cannot find file {} under dir {}", file_path.c_str(), root_dir_.c_str()); return Status(eFail); } index = iter->second; struct zip_file* pzip = zip_fopen_index(zip_, index, 0); if (nullptr == pzip) { ERROR("read file {} in zip file failed, whose index is {}", file_path.c_str(), index); return Status(eFail); } struct zip_stat stat {}; if ((ret = zip_stat_index(zip_, index, 0, &stat)) < 0) { ERROR("get stat of file {} error, ret {}", file_path.c_str(), ret); return Status(eFail); } DEBUG("file size {}", (int)stat.size); std::vector buf(stat.size); if ((ret = zip_fread(pzip, buf.data(), stat.size)) < 0) { ERROR("read data of file {} error, ret {}", file_path.c_str(), ret); return Status(eFail); } return std::string(buf.begin(), buf.end()); } Result ReadMeta() const override { OUTCOME_TRY(auto deploy_json, ReadFile("deploy.json")); try { deploy_meta_info_t meta; from_json(json::parse(deploy_json), meta); return meta; } catch (std::exception& e) { ERROR("exception happened: {}", e.what()); return Status(eFail); } } private: Result InitZip() { int files = zip_get_num_files(zip_); INFO("there are {} files in sdk model file", files); if (files == 0) { return Status(eFail); } for (int i = 0; i < files; ++i) { struct zip_stat stat; zip_stat_init(&stat); zip_stat_index(zip_, i, 0, &stat); fs::path path(stat.name); auto file_name = path.filename().string(); if (file_name == ".") { DEBUG("{}-th file name is: {}, which is a directory", i, stat.name); } else { DEBUG("{}-th file name is: {}, which is a file", i, stat.name); file_index_[file_name] = i; } } return success(); } #if LIBZIP_VERSION_MAJOR >= 1 struct zip_source* source_{}; #endif struct zip* zip_{}; // root directory in zip file std::string root_dir_; // a map between file path and its index in zip file std::map file_index_; }; class ZipModelImplRegister { public: ZipModelImplRegister() { (void)ModelRegistry::Get().Register("ZipModel", []() -> std::unique_ptr { return std::make_unique(); }); } }; static ZipModelImplRegister folder_model_register; } // namespace mmdeploy