From c6cf781fa6b8e9e5e031bb500c9c8a720de11173 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Fri, 10 Mar 2023 10:28:45 +0800 Subject: [PATCH] [Enhancement][Cherry-Pick #1858] Add optional `softmax` in `LinearClsHead` (#1863) * add softmax in cls postprocess * minor (cherry picked from commit bcb93ead589e6103377b1f8d6cfdc8f5f24b9527) --- csrc/mmdeploy/codebase/mmcls/linear_cls.cpp | 29 ++++++++++++++++----- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp b/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp index a29a9a7ff..3e1f171e3 100644 --- a/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp +++ b/csrc/mmdeploy/codebase/mmcls/linear_cls.cpp @@ -18,6 +18,7 @@ class LinearClsHead : public MMClassification { public: explicit LinearClsHead(const Value& cfg) : MMClassification(cfg) { if (cfg.contains("params")) { + softmax_ = cfg["params"].value("softmax", false); topk_ = cfg["params"].value("topk", 1); if (topk_ <= 0) { MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_); @@ -47,16 +48,31 @@ class LinearClsHead : public MMClassification { private: Value GetLabels(const Tensor& scores, int class_num) const { auto scores_data = scores.data(); + auto topk = std::min(topk_, class_num); Labels output; - output.reserve(topk_); + output.reserve(topk); std::vector idx(class_num); iota(begin(idx), end(idx), 0); - partial_sort(begin(idx), begin(idx) + topk_, end(idx), + partial_sort(begin(idx), begin(idx) + topk, end(idx), [&](int i, int j) { return scores_data[i] > scores_data[j]; }); - for (int i = 0; i < topk_; ++i) { - auto label = Label{idx[i], scores_data[idx[i]]}; - MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score); - output.push_back(label); + + auto sum_exp = 0.f; + std::vector exp_scores; + if (softmax_) { + exp_scores.reserve(class_num); + auto max_val = scores_data[idx[0]]; + for (int i = 0; i < class_num; ++i) { + sum_exp += exp_scores.emplace_back(std::exp(scores_data[i] - max_val)); + } + } + for (int i = 0; i < topk; ++i) { + float score = 0.f; + if (softmax_) { + score = exp_scores[idx[i]] / sum_exp; + } else { + score = scores_data[idx[i]]; + } + output.push_back({idx[i], score}); } return to_value(std::move(output)); } @@ -64,6 +80,7 @@ class LinearClsHead : public MMClassification { private: static constexpr const auto kHost = Device{0}; + bool softmax_{false}; int topk_{1}; };