// Copyright (c) OpenMMLab. All rights reserved. #ifndef MMDEPLOY_REGISTRY_H #define MMDEPLOY_REGISTRY_H #include #include #include #include #include #include "value.h" namespace mmdeploy { namespace detail { template struct get_return_type { using type = std::unique_ptr; }; template struct get_return_type> { using type = typename EntryType::type; }; template using get_return_type_t = typename get_return_type::type; } // namespace detail template class Creator { public: using ReturnType = detail::get_return_type_t; public: virtual ~Creator() = default; virtual const char *GetName() const = 0; virtual int GetVersion() const = 0; virtual ReturnType Create(const Value &args) = 0; }; template class Registry { public: static Registry &Get() { static Registry registry; return registry; } bool AddCreator(Creator &creator) { auto key = creator.GetName(); if (entries_.find(key) == entries_.end()) { entries_.insert(std::make_pair(key, &creator)); return true; } for (auto iter = entries_.lower_bound(key); iter != entries_.upper_bound(key); ++iter) { if (iter->second->GetVersion() == creator.GetVersion()) { return false; } } entries_.insert(std::make_pair(key, &creator)); return true; } Creator *GetCreator(const std::string &type, int version = 0) { auto iter = entries_.find(type); if (iter == entries_.end()) { return nullptr; } if (0 == version) { return iter->second; } for (auto iter = entries_.lower_bound(type); iter != entries_.upper_bound(type); ++iter) { if (iter->second->GetVersion() == version) { return iter->second; } } return nullptr; } std::vector ListCreators() { std::vector keys; for (const auto &[key, _] : entries_) { keys.push_back(key); } return keys; } private: Registry() = default; private: std::multimap *> entries_; }; template class Registerer { public: Registerer() { Registry::Get().AddCreator(inst_); } private: CreatorType inst_; }; } // namespace mmdeploy #define REGISTER_MODULE(EntryType, CreatorType) \ static ::mmdeploy::Registerer g_register_##EntryType##_##CreatorType{}; #define DECLARE_AND_REGISTER_MODULE(base_type, module_name, version) \ class module_name##Creator : public ::mmdeploy::Creator { \ public: \ module_name##Creator() = default; \ ~module_name##Creator() = default; \ const char *GetName() const override { return #module_name; } \ int GetVersion() const override { return version; } \ \ std::unique_ptr Create(const Value &value) override { \ return std::make_unique(value); \ } \ }; \ REGISTER_MODULE(base_type, module_name##Creator); #endif // MMDEPLOY_REGISTRY_H