From 3b62b21fac61a7ffab6795cf3cb5447a4d3d9327 Mon Sep 17 00:00:00 2001
From: AllentDan <41138331+AllentDan@users.noreply.github.com>
Date: Wed, 17 May 2023 15:47:05 +0800
Subject: [PATCH] add attention convertor (#2064)

---
 .../codebase/mmocr/attention_convertor.cpp    | 207 ++++++++++++++++++
 .../codebase/mmocr/deploy/text_recognition.py |   7 +-
 2 files changed, 212 insertions(+), 2 deletions(-)
 create mode 100644 csrc/mmdeploy/codebase/mmocr/attention_convertor.cpp

diff --git a/csrc/mmdeploy/codebase/mmocr/attention_convertor.cpp b/csrc/mmdeploy/codebase/mmocr/attention_convertor.cpp
new file mode 100644
index 000000000..0af9e0acd
--- /dev/null
+++ b/csrc/mmdeploy/codebase/mmocr/attention_convertor.cpp
@@ -0,0 +1,207 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include <algorithm>
+#include <sstream>
+
+#include "mmdeploy/core/device.h"
+#include "mmdeploy/core/model.h"
+#include "mmdeploy/core/registry.h"
+#include "mmdeploy/core/tensor.h"
+#include "mmdeploy/core/utils/device_utils.h"
+#include "mmdeploy/core/utils/formatter.h"
+#include "mmdeploy/core/value.h"
+#include "mmdeploy/experimental/module_adapter.h"
+#include "mmocr.h"
+
+namespace mmdeploy::mmocr {
+
+using std::string;
+using std::vector;
+
+class AttnConvertor : public MMOCR {
+ public:
+  explicit AttnConvertor(const Value& cfg) : MMOCR(cfg) {
+    auto model = cfg["context"]["model"].get<Model>();
+    if (!cfg.contains("params")) {
+      MMDEPLOY_ERROR("'params' is required, but it's not in the config");
+      throw_exception(eInvalidArgument);
+    }
+    // BaseConverter
+    auto& _cfg = cfg["params"];
+    if (_cfg.contains("dict_file")) {
+      auto filename = _cfg["dict_file"].get<std::string>();
+      auto content = model.ReadFile(filename).value();
+      idx2char_ = SplitLines(content);
+    } else if (_cfg.contains("dict_list")) {
+      from_value(_cfg["dict_list"], idx2char_);
+    } else if (_cfg.contains("dict_type")) {
+      auto dict_type = _cfg["dict_type"].get<std::string>();
+      if (dict_type == "DICT36") {
+        idx2char_ = SplitChars(DICT36);
+      } else if (dict_type == "DICT90") {
+        idx2char_ = SplitChars(DICT90);
+      } else {
+        MMDEPLOY_ERROR("unknown dict_type: {}", dict_type);
+        throw_exception(eInvalidArgument);
+      }
+    } else {
+      MMDEPLOY_ERROR("either dict_file, dict_list or dict_type must be specified");
+      throw_exception(eInvalidArgument);
+    }
+    // Update Dictionary
+
+    bool with_start = _cfg.value("with_start", false);
+    bool with_end = _cfg.value("with_end", false);
+    bool same_start_end = _cfg.value("same_start_end", false);
+    bool with_padding = _cfg.value("with_padding", false);
+    bool with_unknown = _cfg.value("with_unknown", false);
+    if (with_start && with_end && same_start_end) {
+      idx2char_.emplace_back("<BOS/EOS>");
+      start_idx_ = static_cast<int>(idx2char_.size()) - 1;
+      end_idx_ = start_idx_;
+    } else {
+      if (with_start) {
+        idx2char_.emplace_back("<BOS>");
+        start_idx_ = static_cast<int>(idx2char_.size()) - 1;
+      }
+      if (with_end) {
+        idx2char_.emplace_back("<EOS>");
+        end_idx_ = static_cast<int>(idx2char_.size()) - 1;
+      }
+    }
+
+    if (with_padding) {
+      idx2char_.emplace_back("<PAD>");
+      padding_idx_ = static_cast<int>(idx2char_.size()) - 1;
+    }
+    if (with_unknown) {
+      idx2char_.emplace_back("<UKN>");
+      unknown_idx_ = static_cast<int>(idx2char_.size()) - 1;
+    }
+
+    vector<string> ignore_chars;
+    if (cfg.contains("ignore_chars")) {
+      for (int i = 0; i < cfg["ignore_chars"].size(); i++)
+        ignore_chars.emplace_back(cfg["ignore_chars"][i].get<std::string>());
+    } else {
+      ignore_chars.emplace_back("padding");
+    }
+    std::map<string, int> mapping_table = {
+        {"padding", padding_idx_}, {"end", end_idx_}, {"unknown", unknown_idx_}};
+    for (int i = 0; i < ignore_chars.size(); i++) {
+      if (mapping_table.find(ignore_chars[i]) != mapping_table.end()) {
+        ignore_indexes_.emplace_back(mapping_table.at(ignore_chars[i]));
+      }
+    }
+
+    model_ = model;
+  }
+
+  Result<Value> operator()(const Value& _data, const Value& _prob) {
+    auto d_conf = _prob["output"].get<Tensor>();
+
+    if (!(d_conf.shape().size() == 3 && d_conf.data_type() == DataType::kFLOAT)) {
+      MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", d_conf.shape(),
+                     (int)d_conf.data_type());
+      return Status(eNotSupported);
+    }
+
+    OUTCOME_TRY(auto h_conf, MakeAvailableOnDevice(d_conf, Device{0}, stream()));
+    OUTCOME_TRY(stream().Wait());
+
+    auto data = h_conf.data<float>();
+
+    auto shape = d_conf.shape();
+    auto w = static_cast<int>(shape[1]);
+    auto c = static_cast<int>(shape[2]);
+
+    float valid_ratio = 1;
+    if (_data["img_metas"].contains("valid_ratio")) {
+      valid_ratio = _data["img_metas"]["valid_ratio"].get<float>();
+    }
+    auto [indexes, scores] = Tensor2Idx(data, w, c, valid_ratio);
+
+    auto text = Idx2Str(indexes);
+    MMDEPLOY_DEBUG("text: {}", text);
+
+    TextRecognition output{text, scores};
+
+    return make_pointer(to_value(output));
+  }
+
+  std::pair<vector<int>, vector<float> > Tensor2Idx(const float* data, int w, int c,
+                                                    float valid_ratio) {
+    auto decode_len = w;
+    vector<int> indexes;
+    indexes.reserve(decode_len);
+    vector<float> scores;
+    scores.reserve(decode_len);
+    for (int t = 0; t < decode_len; ++t, data += c) {
+      vector<float> prob(data, data + c);
+      auto iter = max_element(begin(prob), end(prob));
+      auto index = static_cast<int>(iter - begin(prob));
+      if (index == end_idx_) break;
+      if (std::find(ignore_indexes_.begin(), ignore_indexes_.end(), index) ==
+          ignore_indexes_.end()) {
+        indexes.push_back(index);
+        scores.push_back(*iter);
+      }
+    }
+    return {indexes, scores};
+  }
+
+  string Idx2Str(const vector<int>& indexes) {
+    size_t count = 0;
+    for (const auto& idx : indexes) {
+      count += idx2char_[idx].size();
+    }
+    std::string text;
+    text.reserve(count);
+    for (const auto& idx : indexes) {
+      text += idx2char_[idx];
+    }
+    return text;
+  }
+
+ protected:
+  static vector<string> SplitLines(const string& s) {
+    std::istringstream is(s);
+    vector<string> ret;
+    string line;
+    while (std::getline(is, line)) {
+      ret.push_back(std::move(line));
+    }
+    return ret;
+  }
+
+  static vector<string> SplitChars(const string& s) {
+    vector<string> ret;
+    ret.reserve(s.size());
+    for (char c : s) {
+      ret.push_back({c});
+    }
+    return ret;
+  }
+
+  static constexpr const auto DICT36 = R"(0123456789abcdefghijklmnopqrstuvwxyz)";
+  static constexpr const auto DICT90 = R"(0123456789abcdefghijklmnopqrstuvwxyz)"
+                                       R"(ABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'())"
+                                       R"(*+,-./:;<=>?@[\]_`~)";
+
+  static constexpr const auto kHost = Device(0);
+
+  Model model_;
+
+  static constexpr const int blank_idx_{0};
+  int padding_idx_{-1};
+  int end_idx_{-1};
+  int start_idx_{-1};
+  int unknown_idx_{-1};
+
+  vector<int> ignore_indexes_;
+  vector<string> idx2char_;
+};
+
+MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMOCR, AttnConvertor);
+
+}  // namespace mmdeploy::mmocr
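Note on the decode step: `Tensor2Idx` performs a greedy per-step argmax over the `[N, W, C]` confidence tensor, stops at the end token, and drops any index listed in `ignore_indexes_` (`valid_ratio` is read but, as in the code above, does not shorten the decode length). A minimal NumPy sketch of the same logic, for reference only — the function name and the NumPy dependency are illustrative, not part of this patch:

    import numpy as np

    def tensor2idx(scores, end_idx, ignore_indexes):
        """scores: (seq_len, num_classes) float array for one sample."""
        indexes, confidences = [], []
        for step in scores:
            index = int(np.argmax(step))    # greedy choice per time step
            if index == end_idx:            # decoding stops at the end token
                break
            if index not in ignore_indexes:  # e.g. the padding index
                indexes.append(index)
                confidences.append(float(step[index]))
        return indexes, confidences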
diff --git a/mmdeploy/codebase/mmocr/deploy/text_recognition.py b/mmdeploy/codebase/mmocr/deploy/text_recognition.py
index cf943560c..f99c21f9f 100644
--- a/mmdeploy/codebase/mmocr/deploy/text_recognition.py
+++ b/mmdeploy/codebase/mmocr/deploy/text_recognition.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
 from typing import Callable, Dict, Optional, Sequence, Tuple, Union
 
 import mmengine
@@ -260,11 +261,13 @@ class TextRecognition(BaseTask):
         postprocess = self.model_cfg.model.decoder.postprocessor
         if postprocess.type == 'CTCPostProcessor':
             postprocess.type = 'CTCConvertor'
+        if postprocess.type == 'AttentionPostprocessor':
+            postprocess.type = 'AttnConvertor'
 
         import shutil
         shutil.copy(self.model_cfg.dictionary.dict_file,
                     f'{work_dir}/dict_file.txt')
-        with_padding = self.model_cfg.dictionary.get('with_padding', False)
-        params = dict(dict_file='dict_file.txt', with_padding=with_padding)
+        params = deepcopy(self.model_cfg.dictionary)
+        params.update(dict(dict_file='dict_file.txt'))
         postprocess['params'] = params
         return postprocess
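With this change, `get_postprocess` forwards the whole `dictionary` config as `params` instead of only `with_padding`, so the C++ `AttnConvertor` above sees every flag it checks (`with_start`, `with_end`, `same_start_end`, `with_padding`, `with_unknown`). A sketch of the resulting postprocess entry for a typical attention-based recognizer — the flag values are illustrative of a common MMOCR dictionary config, not taken from this patch:

    postprocess = dict(
        type='AttnConvertor',
        params=dict(
            type='Dictionary',          # carried over from model_cfg.dictionary
            dict_file='dict_file.txt',  # rewritten to the copy in work_dir
            with_start=True,
            with_end=True,
            same_start_end=True,
            with_padding=True,
            with_unknown=True))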