mmdeploy/csrc/net/net_module.cpp

238 lines
7.0 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved.
#include "net_module.h"
#include <thread>
#include "archive/value_archive.h"
#include "core/logger.h"
#include "core/model.h"
#include "core/module.h"
#include "core/net.h"
#include "core/registry.h"
#include "core/utils/formatter.h"
#include "core/utils/scope_counter.h"
#include "experimental/module_adapter.h"
using std::string;
using std::vector;
namespace mmdeploy {
struct NetModule::Impl {
using Input = std::map<std::string, Tensor>;
using Output = std::map<std::string, Tensor>;
explicit Impl(const Value& args) {
DEBUG("Net Module cfg: {}", args);
auto init = [&]() -> Result<void> {
auto name = args["name"].get<std::string>();
auto& context = args["context"];
auto model = context["model"].get<Model>();
OUTCOME_TRY(auto config, model.GetModelConfig(name));
device_ = context.value("device", Device{"cpu"});
stream_ = context.value("stream", Stream::GetDefault(device_));
auto creator = Registry<Net>::Get().GetCreator(config.backend);
if (!creator) {
ERROR("Net backend not found: {}", config.backend);
return Status(eEntryNotFound);
}
auto net_cfg = args;
net_cfg["context"].update({{"device", device_}, {"stream", stream_}});
net_ = creator->Create(net_cfg);
if (!net_) {
return Status(eFail);
}
OUTCOME_TRY(InitializeInputTensors(args));
OUTCOME_TRY(InitializeOutputTensors(args));
return success();
};
init().value();
}
Result<void> InitializeInputTensors(const Value& args) {
auto inputs = args.value<Value>("input_map", ValueType::kObject);
for (auto it = inputs.begin(); it != inputs.end(); ++it) {
input_mapping_.insert({(*it).get<std::string>(), it.key()});
}
OUTCOME_TRY(inputs_, net_->GetInputTensors());
for (const auto& t : inputs_) {
input_mapping_.insert({t.name(), t.name()});
}
return success();
}
Result<void> InitializeOutputTensors(const Value& args) {
auto outputs = args.value<Value>("output_map", ValueType::kObject);
for (auto it = outputs.begin(); it != outputs.end(); ++it) {
output_mapping_.insert({(*it).get<std::string>(), it.key()});
}
OUTCOME_TRY(outputs_, net_->GetOutputTensors());
for (const auto& t : outputs_) {
output_mapping_.insert({t.name(), t.name()});
}
return success();
}
Result<TensorShape> InferInputShape(const vector<Tensor>& input) {
auto batch_size = input.size();
auto& exemplar = input.front();
auto shape = exemplar.shape();
if (batch_size == 1) {
return shape;
}
if (shape[0] != 1) {
ERROR("unsupported shape for batch assemble: {}", shape);
return Status(eNotSupported);
}
for (int i = 1; i < input.size(); ++i) {
auto& sample = input[i];
if (sample.shape() != shape) {
ERROR("shapes are not consistent across the batch");
return Status(eNotSupported);
}
}
shape[0] = static_cast<int64_t>(batch_size);
return shape;
}
Result<vector<TensorShape> > InferInputShape(const vector<vector<Tensor> >& inputs) {
vector<TensorShape> shapes;
shapes.reserve(inputs.size());
for (const auto& input : inputs) {
OUTCOME_TRY(auto shape, InferInputShape(input));
shapes.push_back(std::move(shape));
}
return shapes;
}
Result<std::vector<Output> > Forward(const std::vector<Input>& input) {
// auto t0 = std::chrono::high_resolution_clock::now();
//
auto batch_size = static_cast<int>(input.size());
std::vector<std::vector<Tensor> > input_samples;
input_samples.reserve(inputs_.size());
for (const auto& t : inputs_) {
auto name = input_mapping_.at(t.name());
std::vector<Tensor> tmp;
tmp.reserve(input.size());
for (int i = 0; i < input.size(); ++i) {
auto& sample = input[i];
if (auto it = sample.find(name); it != sample.end()) {
tmp.push_back(it->second);
} else {
ERROR("sample {} missing key {}", i, name);
return Status(eInvalidArgument);
}
}
input_samples.push_back(std::move(tmp));
}
// 1. calculate input shape
OUTCOME_TRY(auto input_shapes, InferInputShape(input_samples));
// 2. call backend's reshape
OUTCOME_TRY(net_->Reshape(input_shapes));
// 3. fill input tensor
for (int i = 0; i < inputs_.size(); ++i) {
auto& src = input_samples[i];
auto& dst = inputs_[i];
if (dst.shape() != input_shapes[i]) {
ERROR("inconsistent input shape, expect {}, got {}", input_shapes[i], dst.shape());
return Status(eFail);
}
if (src.size() > 1) {
for (int j = 0; j < src.size(); ++j) {
auto slice = dst.Slice(j);
OUTCOME_TRY(src[j].CopyTo(slice, stream_));
}
} else {
OUTCOME_TRY(src[0].CopyTo(dst, stream_));
}
}
// 5. forward
OUTCOME_TRY(net_->Forward());
vector<Output> output(batch_size);
for (const auto& t : outputs_) {
auto name = output_mapping_.at(t.name());
auto desc = t.desc();
desc.device = device_;
Tensor tmp(desc);
if (tmp.size()) {
OUTCOME_TRY(t.CopyTo(tmp, stream_));
} else {
WARN("copy skipped due to zero sized tensor");
}
if (output.size() > 1) {
for (int i = 0; i < output.size(); ++i) {
output[i].emplace(name, tmp.Slice(i));
}
} else {
output[0].emplace(name, std::move(tmp));
}
}
return output;
}
Device device_;
Stream stream_;
std::unique_ptr<Net> net_;
Span<Tensor> inputs_;
Span<Tensor> outputs_;
// outer scope to model input names
std::map<std::string, std::string> input_mapping_;
// outer scope to model output names
std::map<std::string, std::string> output_mapping_;
};
NetModule::~NetModule() = default;
NetModule::NetModule(const Value& args) : impl_(std::make_unique<Impl>(args)) {}
Result<Value> NetModule::operator()(const Value& input) {
auto filter = [](const Value& sample) {
Impl::Input tensors;
for (auto it = sample.begin(); it != sample.end(); ++it) {
if (it->is_any<Tensor>()) {
tensors.insert({it.key(), it->get<Tensor>()});
}
}
return tensors;
};
std::vector<Impl::Input> batch;
if (input.is_array()) {
batch.reserve(input.size());
for (const auto& sample : input) {
batch.push_back(filter(sample));
}
} else if (input.is_object()) {
batch.push_back(filter(input));
} else {
return Status(eNotSupported);
}
OUTCOME_TRY(auto batch_output, impl_->Forward(batch));
if (input.is_array()) {
return to_value(batch_output);
} else {
return to_value(batch_output.at(0));
}
}
class NetModuleCreator : public Creator<Module> {
public:
const char* GetName() const override { return "Net"; }
int GetVersion() const override { return 0; }
std::unique_ptr<Module> Create(const Value& value) override {
return CreateTask(NetModule{value});
}
};
REGISTER_MODULE(Module, NetModuleCreator);
} // namespace mmdeploy