[Fix] fix cls head in SDK (#1420)

* fix cls head

* resolve comments
pull/1427/head
AllentDan 2022-11-24 14:15:34 +08:00 committed by GitHub
parent de96f51231
commit 301035a06f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 3 deletions

View File

@ -47,13 +47,14 @@ class LinearClsHead : public MMClassification {
private:
Value GetLabels(const Tensor& scores, int class_num) const {
auto scores_data = scores.data<float>();
auto topk = std::min(topk_, class_num);
Labels output;
output.reserve(topk_);
output.reserve(topk);
std::vector<int> idx(class_num);
iota(begin(idx), end(idx), 0);
partial_sort(begin(idx), begin(idx) + topk_, end(idx),
partial_sort(begin(idx), begin(idx) + topk, end(idx),
[&](int i, int j) { return scores_data[i] > scores_data[j]; });
for (int i = 0; i < topk_; ++i) {
for (int i = 0; i < topk; ++i) {
auto label = Label{idx[i], scores_data[idx[i]]};
MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score);
output.push_back(label);