[Refactor] Support batch inference with shape clustering (#1733)

* refactor `NetModule`

* name

* fix sorting

* fix indices

(cherry picked from commit f5a05b52191e19d67225f332c19dfd53dac55843)
This commit is contained in:
Li Zhang 2023-02-08 20:36:29 +08:00 committed by GitHub
parent bc1b6440cd
commit a8f8c4febe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,9 @@
// Copyright (c) OpenMMLab. All rights reserved. // Copyright (c) OpenMMLab. All rights reserved.
#include "net_module.h" #include "mmdeploy/net/net_module.h"
#include <algorithm>
#include <numeric>
#include <thread> #include <thread>
#include "mmdeploy/archive/value_archive.h" #include "mmdeploy/archive/value_archive.h"
@ -31,6 +33,11 @@ struct NetModule::Impl {
is_profiling_ = true; is_profiling_ = true;
} }
auto model = context["model"].get<Model>(); auto model = context["model"].get<Model>();
for (const auto& meta : model.meta().models) {
if (meta.name == name) {
max_batch_size_ = meta.batch_size;
}
}
OUTCOME_TRY(auto config, model.GetModelConfig(name)); OUTCOME_TRY(auto config, model.GetModelConfig(name));
device_ = context.value("device", Device{"cpu"}); device_ = context.value("device", Device{"cpu"});
stream_ = context.value("stream", Stream::GetDefault(device_)); stream_ = context.value("stream", Stream::GetDefault(device_));
@ -78,7 +85,7 @@ struct NetModule::Impl {
return success(); return success();
} }
Result<TensorShape> InferInputShape(const vector<Tensor>& input) { static Result<TensorShape> InferBatchShape(const vector<Tensor>& input) {
auto batch_size = input.size(); auto batch_size = input.size();
auto& exemplar = input.front(); auto& exemplar = input.front();
auto shape = exemplar.shape(); auto shape = exemplar.shape();
@ -86,13 +93,13 @@ struct NetModule::Impl {
return shape; return shape;
} }
if (shape[0] != 1) { if (shape[0] != 1) {
MMDEPLOY_ERROR("unsupported shape for batch assemble: {}", shape); MMDEPLOY_WARN("unsupported shape for batch assemble: {}", shape);
return Status(eNotSupported); return Status(eNotSupported);
} }
for (int i = 1; i < input.size(); ++i) { for (int i = 1; i < input.size(); ++i) {
auto& sample = input[i]; auto& sample = input[i];
if (sample.shape() != shape) { if (sample.shape() != shape) {
MMDEPLOY_ERROR("shapes are not consistent across the batch"); MMDEPLOY_WARN("shapes are not consistent across the batch");
return Status(eNotSupported); return Status(eNotSupported);
} }
} }
@ -100,90 +107,175 @@ struct NetModule::Impl {
return shape; return shape;
} }
// Infers the batched shape of every network input.
// `inputs[i]` holds the per-sample tensors of the i-th input; the result holds one shape
// per input, in the same order. Propagates the error of the per-input overload for the
// first input whose samples cannot be assembled into a batch.
static Result<vector<TensorShape>> InferBatchShape(const vector<vector<Tensor>>& inputs) {
  vector<TensorShape> batch_shapes;
  batch_shapes.reserve(inputs.size());
  for (size_t idx = 0; idx < inputs.size(); ++idx) {
    OUTCOME_TRY(auto batch_shape, InferBatchShape(inputs[idx]));
    batch_shapes.emplace_back(std::move(batch_shape));
  }
  return batch_shapes;
}
// Regroups the per-sample inputs by network input.
// For every tensor in `inputs_`, the key to look up is `input_mapping_.at(tensor.name())`;
// the tensor stored under that key is gathered from each sample in `inputs`.
// Returns a list indexed as [input][sample], or eInvalidArgument when some sample lacks
// one of the keys. `input_mapping_.at` throws if a tensor name has no mapping.
Result<vector<vector<Tensor>>> CollectInputTensors(const vector<Input>& inputs) {
  vector<vector<Tensor>> input_samples;
  input_samples.reserve(inputs_.size());
  for (const auto& t : inputs_) {
    // the mapping is not modified here, so a reference avoids copying the string
    const auto& name = input_mapping_.at(t.name());
    auto& tmp = input_samples.emplace_back();
    tmp.reserve(inputs.size());  // exactly one tensor per sample is appended below
    for (size_t i = 0; i < inputs.size(); ++i) {
      const auto& sample = inputs[i];
      if (auto it = sample.find(name); it != sample.end()) {
        tmp.push_back(it->second);
      } else {
        MMDEPLOY_ERROR("sample {} missing key {}", i, name);
        return Status(eInvalidArgument);
      }
    }
  }
  return input_samples;
}
// 1. calculate input shape void SaveBatch(vector<vector<Tensor>> samples, vector<int> indices,
OUTCOME_TRY(auto input_shapes, InferInputShape(input_samples)); vector<vector<vector<Tensor>>>& batch_tensors,
vector<vector<TensorShape>>& batch_shapes,
// 2. call backend's reshape vector<vector<int>>& batch_sample_idxs) const {
OUTCOME_TRY(net_->Reshape(input_shapes)); if (auto maybe_batch_shape = InferBatchShape(samples)) {
batch_shapes.push_back(maybe_batch_shape.value());
// 3. fill input tensor batch_tensors.push_back(std::move(samples));
for (int i = 0; i < inputs_.size(); ++i) { batch_sample_idxs.push_back(std::move(indices));
auto& src = input_samples[i]; } else {
auto& dst = inputs_[i]; // cannot assemble batch, do it one by one
if (dst.shape() != input_shapes[i]) { for (int k = 0; k < indices.size(); ++k) {
MMDEPLOY_ERROR("inconsistent input shape, expect {}, got {}", input_shapes[i], dst.shape()); auto& shapes = batch_shapes.emplace_back();
return Status(eFail); auto& batch = batch_tensors.emplace_back(inputs_.size());
} batch_sample_idxs.push_back({indices[k]});
if (src.size() > 1) { for (int j = 0; j < inputs_.size(); ++j) {
for (int j = 0; j < src.size(); ++j) { shapes.push_back(samples[j][k].shape());
auto slice = dst.Slice(j); batch[j].push_back(std::move(samples[j][k]));
OUTCOME_TRY(src[j].CopyTo(slice, stream_));
} }
} else { }
OUTCOME_TRY(src[0].CopyTo(dst, stream_)); }
}
void SamplesToBatches(const vector<vector<Tensor>>& input_samples, size_t n_samples,
vector<vector<vector<Tensor>>>& batch_tensors,
vector<vector<TensorShape>>& batch_shapes,
vector<vector<int>>& batch_sample_idxs) const {
// concat all shapes in samples to make comparison easier
vector<vector<int64_t>> concat_shapes;
concat_shapes.reserve(n_samples);
for (size_t i = 0; i < n_samples; ++i) {
auto& shape = concat_shapes.emplace_back();
for (const auto& input : input_samples) {
shape.insert(shape.end(), input[i].shape().begin(), input[i].shape().end());
} }
} }
// 5. forward // cluster samples by concatenated shapes
OUTCOME_TRY(net_->Forward()); vector<int> shape_idxs(concat_shapes.size());
std::iota(shape_idxs.begin(), shape_idxs.end(), 0);
std::sort(shape_idxs.begin(), shape_idxs.end(),
[&concat_shapes](int i, int j) { return concat_shapes[i] < concat_shapes[j]; });
shape_idxs.erase(std::unique(shape_idxs.begin(), shape_idxs.end(),
[&concat_shapes](int i, int j) {
return concat_shapes[i] == concat_shapes[j];
}),
shape_idxs.end());
vector<Output> output(batch_size); // generate batches of samples with equal shapes, limit the batch size by max_batch_size_
for (const auto& t : outputs_) { for (const auto ref_shape_idx : shape_idxs) {
auto name = output_mapping_.at(t.name()); const auto& ref_shape = concat_shapes[ref_shape_idx];
auto desc = t.desc(); vector<vector<Tensor>> samples(inputs_.size());
desc.device = device_; vector<int> indices;
Tensor tmp(desc); for (size_t i = 0; i < concat_shapes.size(); ++i) {
if (tmp.size()) { if (concat_shapes[i] == ref_shape) {
OUTCOME_TRY(t.CopyTo(tmp, stream_)); for (size_t j = 0; j < inputs_.size(); ++j) {
} else { samples[j].push_back(input_samples[j][i]);
MMDEPLOY_WARN("copy skipped due to zero sized tensor"); }
} indices.push_back(static_cast<int>(i));
if (output.size() > 1) { if (indices.size() == max_batch_size_) {
for (int i = 0; i < output.size(); ++i) { SaveBatch(std::move(samples), std::move(indices), batch_tensors, batch_shapes,
output[i].emplace(name, tmp.Slice(i)); batch_sample_idxs);
samples = vector<vector<Tensor>>(inputs_.size());
indices = {};
}
} }
} else { }
output[0].emplace(name, std::move(tmp)); if (!indices.empty()) {
SaveBatch(std::move(samples), std::move(indices), batch_tensors, batch_shapes,
batch_sample_idxs);
} }
} }
}
// Runs inference for all `inputs` and returns one Output per input, in the same order.
// The samples are regrouped per network input, split into batches by SamplesToBatches,
// and each batch goes through reshape -> copy inputs -> forward -> copy outputs.
Result<vector<Output>> Forward(const vector<Input>& inputs) {
  OUTCOME_TRY(auto tensors_by_input, CollectInputTensors(inputs));

  vector<vector<vector<Tensor>>> batches;
  vector<vector<TensorShape>> shapes_per_batch;
  vector<vector<int>> sample_ids_per_batch;
  SamplesToBatches(tensors_by_input, inputs.size(), batches, shapes_per_batch,
                   sample_ids_per_batch);

  vector<Output> results(inputs.size());
  const size_t n_batches = batches.size();
  for (size_t b = 0; b != n_batches; ++b) {
    OUTCOME_TRY(net_->Reshape(shapes_per_batch[b]));
    OUTCOME_TRY(CopyInputTensors(batches[b], shapes_per_batch[b]));
    OUTCOME_TRY(net_->Forward());
    OUTCOME_TRY(CopyOutputTensors(sample_ids_per_batch[b], results));
    // sync if not the last batch
    // NOTE(review): presumably because the next iteration reshapes and overwrites the
    // network's tensors while copies may still be pending on the stream -- confirm.
    const bool is_last = (b + 1 == n_batches);
    if (!is_last) {
      OUTCOME_TRY(stream_.Wait());
    }
  }

  if (is_profiling_) {
    OUTCOME_TRY(stream_.Wait());
  }
  return results;
}
// Copies one batch of samples into the network's input tensors on `stream_`.
// `batch` is indexed as [input][sample] and `shapes[i]` is the shape the i-th input
// tensor must already have. Returns eFail on a shape mismatch. With more than one sample
// each sample is copied into its own slice of the input tensor; otherwise the single
// sample is copied into the whole tensor.
Result<void> CopyInputTensors(const vector<vector<Tensor>>& batch,
                              const vector<TensorShape>& shapes) const {
  for (size_t idx = 0; idx < inputs_.size(); ++idx) {
    auto& target = inputs_[idx];
    const auto& expected = shapes[idx];
    if (target.shape() != expected) {
      MMDEPLOY_ERROR("inconsistent input shape, expect {}, got {}", expected, target.shape());
      return Status(eFail);
    }
    const auto& sources = batch[idx];
    if (sources.size() <= 1) {
      OUTCOME_TRY(sources.front().CopyTo(target, stream_));
      continue;
    }
    const int n_sources = static_cast<int>(sources.size());
    for (int k = 0; k < n_sources; ++k) {
      OUTCOME_TRY(target.Slice(k).CopyFrom(sources[k], stream_));
    }
  }
  return success();
}
// Copies the network outputs of the current batch into the per-sample results.
// For every tensor in `outputs_` a new tensor with the same descriptor, but placed on
// `device_`, receives a copy on `stream_`. The key it is stored under is
// `output_mapping_.at(tensor.name())`. With more than one sample in the batch the k-th
// slice goes to `outputs[indices[k]]`; otherwise the whole tensor goes to the single sample.
Result<void> CopyOutputTensors(const vector<int>& indices, vector<Output>& outputs) {
  const bool is_batched = indices.size() > 1;
  for (const auto& net_output : outputs_) {
    const auto& key = output_mapping_.at(net_output.name());
    auto desc = net_output.desc();
    desc.device = device_;
    Tensor copy(desc);
    if (copy.size()) {
      OUTCOME_TRY(net_output.CopyTo(copy, stream_));
    } else {
      MMDEPLOY_WARN("copy skipped due to zero sized tensor");
    }
    if (!is_batched) {
      outputs[indices.front()].emplace(key, std::move(copy));
      continue;
    }
    const int n_samples = static_cast<int>(indices.size());
    for (int k = 0; k < n_samples; ++k) {
      outputs[indices[k]].emplace(key, copy.Slice(k));
    }
  }
  return success();
}
Device device_; Device device_;
@ -195,6 +287,7 @@ struct NetModule::Impl {
std::map<std::string, std::string> input_mapping_; std::map<std::string, std::string> input_mapping_;
// outer scope to model output names // outer scope to model output names
std::map<std::string, std::string> output_mapping_; std::map<std::string, std::string> output_mapping_;
int max_batch_size_{1};
bool is_profiling_{false}; bool is_profiling_{false};
}; };