[Enhancement] Add optional `softmax` in `LinearClsHead` (#1858)
* add softmax in cls postprocess
* minor

Branch: pull/1864/head
parent f69c636a2e
commit bcb93ead58
@@ -18,6 +18,7 @@ class LinearClsHead : public MMClassification {
  public:
   explicit LinearClsHead(const Value& cfg) : MMClassification(cfg) {
     if (cfg.contains("params")) {
+      softmax_ = cfg["params"].value("softmax", false);
       topk_ = cfg["params"].value("topk", 1);
       if (topk_ <= 0) {
         MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_);
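The new flag is read with a defaulted lookup, so configs written before this change keep working. For readers unfamiliar with the SDK's `Value` type, the calls used here (`contains()` plus `value()` with a fallback) behave much like nlohmann::json; a minimal standalone sketch of the same pattern, using nlohmann::json instead of `Value` (the config literal and variable names are illustrative, not the SDK API):

    #include <iostream>
    #include <nlohmann/json.hpp>

    int main() {
      // Hypothetical head config; "params" and its keys mirror the diff above.
      auto cfg = nlohmann::json::parse(R"({"params": {"softmax": true, "topk": 5}})");

      bool softmax = false;  // same defaults as the members softmax_{false}, topk_{1}
      int topk = 1;
      if (cfg.contains("params")) {
        softmax = cfg["params"].value("softmax", false);  // false if the key is absent
        topk = cfg["params"].value("topk", 1);
      }
      std::cout << "softmax=" << softmax << " topk=" << topk << "\n";
    }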
@@ -54,10 +55,24 @@ class LinearClsHead : public MMClassification {
     iota(begin(idx), end(idx), 0);
     partial_sort(begin(idx), begin(idx) + topk, end(idx),
                  [&](int i, int j) { return scores_data[i] > scores_data[j]; });
+
+    auto sum_exp = 0.f;
+    std::vector<float> exp_scores;
+    if (softmax_) {
+      exp_scores.reserve(class_num);
+      auto max_val = scores_data[idx[0]];
+      for (int i = 0; i < class_num; ++i) {
+        sum_exp += exp_scores.emplace_back(std::exp(scores_data[i] - max_val));
+      }
+    }
     for (int i = 0; i < topk; ++i) {
-      auto label = Label{idx[i], scores_data[idx[i]]};
-      MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score);
-      output.push_back(label);
+      float score = 0.f;
+      if (softmax_) {
+        score = exp_scores[idx[i]] / sum_exp;
+      } else {
+        score = scores_data[idx[i]];
+      }
+      output.push_back({idx[i], score});
     }
     return to_value(std::move(output));
   }
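Two details are worth noting. First, the softmax is computed in a numerically stable way: after `partial_sort`, `idx[0]` is the argmax, so `scores_data[idx[0]]` is the global maximum, and subtracting it before `std::exp` avoids overflow on large logits without changing the result. Second, `sum_exp += exp_scores.emplace_back(...)` relies on C++17, where `emplace_back` returns a reference to the inserted element. A self-contained sketch of the same top-k-plus-softmax post-processing (function and variable names are illustrative, not the SDK API):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <numeric>
    #include <utility>
    #include <vector>

    // Stable softmax over raw scores, returning the top-k (index, probability) pairs.
    std::vector<std::pair<int, float>> topk_softmax(const std::vector<float>& scores, int topk) {
      const int class_num = static_cast<int>(scores.size());
      topk = std::min(topk, class_num);

      std::vector<int> idx(class_num);
      std::iota(idx.begin(), idx.end(), 0);
      std::partial_sort(idx.begin(), idx.begin() + topk, idx.end(),
                        [&](int i, int j) { return scores[i] > scores[j]; });

      // idx[0] now holds the argmax; subtract the max before exp for stability.
      const float max_val = scores[idx[0]];
      float sum_exp = 0.f;
      std::vector<float> exp_scores(class_num);
      for (int i = 0; i < class_num; ++i) {
        exp_scores[i] = std::exp(scores[i] - max_val);
        sum_exp += exp_scores[i];
      }

      std::vector<std::pair<int, float>> output;
      output.reserve(topk);
      for (int i = 0; i < topk; ++i) {
        output.emplace_back(idx[i], exp_scores[idx[i]] / sum_exp);
      }
      return output;
    }

    int main() {
      for (auto [label, prob] : topk_softmax({2.f, 1.f, 0.1f, 5.f}, 2)) {
        std::printf("label_id: %d, score: %.4f\n", label, prob);
      }
    }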
@@ -65,6 +80,7 @@ class LinearClsHead : public MMClassification {
  private:
   static constexpr const auto kHost = Device{0};

+  bool softmax_{false};
   int topk_{1};
 };

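Since `softmax_` defaults to `false`, the head keeps returning raw logits unless the flag is set, so existing deployments are unaffected; enabling it is presumably a matter of adding `"softmax": true` next to `"topk"` under `"params"` in the config passed to this head.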