mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
* sync SDK changes according to performance benchmarks * fix end-of-file lint * fix clang-format issue * fix clang-format by adding 'clang-format off' * remove useless casts * remove 'data' argument of 'operator()' * change 'Tensor2Img' to 'TensorToImg' according to the spec * correct the tensor's name according to the spec Co-authored-by: lvhan028 <lvhan_028@163.com>
68 lines
2.0 KiB
C++
68 lines
2.0 KiB
C++
// Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
#include <algorithm>
#include <numeric>
#include <vector>

#include "codebase/mmcls/mmcls.h"
#include "core/tensor.h"
#include "core/utils/formatter.h"
#include "experimental/module_adapter.h"

using std::vector;
|
|
|
|
namespace mmdeploy::mmcls {
|
|
|
|
class LinearClsHead : public MMClassification {
|
|
public:
|
|
explicit LinearClsHead(const Value& cfg) : MMClassification(cfg) {
|
|
if (cfg.contains("params")) {
|
|
topk_ = cfg["params"].value("topk", 1);
|
|
if (topk_ <= 0) {
|
|
ERROR("'topk' should be greater than 0, but got '{}'", topk_);
|
|
throw_exception(eInvalidArgument);
|
|
}
|
|
}
|
|
}
|
|
|
|
Result<Value> operator()(const Value& infer_res) {
|
|
DEBUG("infer_res: {}", infer_res);
|
|
auto output_tensor = infer_res["output"].get<Tensor>();
|
|
assert(output_tensor.shape().size() >= 2);
|
|
auto class_num = (int)output_tensor.shape()[1];
|
|
|
|
if (output_tensor.device().is_host()) {
|
|
vector<float> scores(output_tensor.data<float>(),
|
|
output_tensor.data<float>() + output_tensor.size());
|
|
OUTCOME_TRY(stream().Wait());
|
|
return GetLabels(scores, class_num);
|
|
} else {
|
|
vector<float> scores(output_tensor.size());
|
|
OUTCOME_TRY(output_tensor.CopyTo(scores.data(), stream()));
|
|
OUTCOME_TRY(stream().Wait());
|
|
return GetLabels(scores, class_num);
|
|
}
|
|
}
|
|
|
|
private:
|
|
Value GetLabels(const vector<float>& scores, int class_num) const {
|
|
ClassifyOutput output;
|
|
output.labels.reserve(topk_);
|
|
std::vector<int> idx(class_num);
|
|
iota(begin(idx), end(idx), 0);
|
|
partial_sort(begin(idx), begin(idx) + topk_, end(idx),
|
|
[&](int i, int j) { return scores[i] > scores[j]; });
|
|
for (int i = 0; i < topk_; ++i) {
|
|
auto label = ClassifyOutput::Label{idx[i], scores[idx[i]]};
|
|
DEBUG("label_id: {}, score: {}", label.label_id, label.score);
|
|
output.labels.push_back(label);
|
|
}
|
|
return to_value(std::move(output));
|
|
}
|
|
|
|
private:
|
|
int topk_{1};
|
|
};
|
|
|
|
REGISTER_CODEBASE_COMPONENT(MMClassification, LinearClsHead);
|
|
|
|
} // namespace mmdeploy::mmcls
|