# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate a per-language recognition training config.

The script loads ``./rec_multi_language_lite_train.yml`` as a base
configuration, rewrites the dictionary path, output directory and label-file
lists for the language chosen with ``-l/--language``, applies any ``-o``
overrides, and writes the result to ``rec_<lang>_lite_train.yml`` in the
current directory.
"""

import yaml
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import os.path
import logging

logging.basicConfig(level=logging.INFO)

# Language codes that map directly onto their own dictionary file.
support_list = {
    "it": "italian",
    "xi": "spanish",
    "pu": "portuguese",
    "ru": "russian",
    "ar": "arabic",
    "ta": "tamil",
    "ug": "uyghur",
    "fa": "persian",
    "ur": "urdu",
    "rs": "serbian latin",
    "oc": "occitan",
    "rsc": "serbian cyrillic",
    "bg": "bulgarian",
    "uk": "ukranian",
    "be": "belarusian",
    "te": "telugu",
    "ka": "kannada",
    "chinese_cht": "chinese tradition",
    "hi": "hindi",
    "mr": "marathi",
    "ne": "nepali",
}

# Language codes grouped by script; every member of a group is collapsed to
# the group's name (and therefore to that group's dictionary file).
latin_lang = [
    "af",
    "az",
    "bs",
    "cs",
    "cy",
    "da",
    "de",
    "es",
    "et",
    "fr",
    "ga",
    "hr",
    "hu",
    "id",
    "is",
    "it",
    "ku",
    "la",
    "lt",
    "lv",
    "mi",
    "ms",
    "mt",
    "nl",
    "no",
    "oc",
    "pi",
    "pl",
    "pt",
    "ro",
    "rs_latin",
    "sk",
    "sl",
    "sq",
    "sv",
    "sw",
    "tl",
    "tr",
    "uz",
    "vi",
    "latin",
]
arabic_lang = ["ar", "fa", "ug", "ur"]
cyrillic_lang = [
    "ru",
    "rs_cyrillic",
    "be",
    "bg",
    "uk",
    "mn",
    "abq",
    "ady",
    "kbd",
    "ava",
    "dar",
    "inh",
    "che",
    "lbe",
    "lez",
    "tab",
    "cyrillic",
]
devanagari_lang = [
    "hi",
    "mr",
    "ne",
    "bh",
    "mai",
    "ang",
    "bho",
    "mah",
    "sck",
    "new",
    "gom",
    "sa",
    "bgc",
    "devanagari",
]
multi_lang = latin_lang + arabic_lang + cyrillic_lang + devanagari_lang

# The base configuration must sit next to the working directory; fail early
# with a download hint when it is missing.
assert os.path.isfile("./rec_multi_language_lite_train.yml"), (
    "Loss basic configuration file rec_multi_language_lite_train.yml. "
    "You can download it from "
    "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
)


def _load_base_config(path):
    """Read the YAML file at ``path`` and return the parsed object.

    The file handle is closed as soon as parsing finishes.
    """
    with open(path, "rb") as config_file:
        # NOTE(security): yaml.Loader can construct arbitrary Python objects.
        # Kept for compatibility with existing configs; only load trusted files.
        return yaml.load(config_file, Loader=yaml.Loader)


global_config = _load_base_config("./rec_multi_language_lite_train.yml")

# Three directories above the current working directory.
# NOTE(review): presumably the repository root when the script is run from
# configs/rec/multi_language/ -- confirm against how the script is invoked.
project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))


class ArgsParser(ArgumentParser):
    """Command-line parser that also applies the language choice to ``global_config``."""

    def __init__(self):
        super(ArgsParser, self).__init__(formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-o", "--opt", nargs="+", help="set configuration options")
        self.add_argument(
            "-l",
            "--language",
            nargs="+",
            help="set language type, support {}".format(support_list),
        )
        self.add_argument(
            "--train",
            type=str,
            help="you can use this command to change the train dataset default path",
        )
        self.add_argument(
            "--val",
            type=str,
            help="you can use this command to change the eval dataset default path",
        )
        self.add_argument(
            "--dict",
            type=str,
            help="you can use this command to change the dictionary default path",
        )
        self.add_argument(
            "--data_dir",
            type=str,
            help="you can use this command to change the dataset default root path",
        )

    def parse_args(self, argv=None):
        """Parse ``argv`` and post-process ``opt`` and ``language``.

        Returns:
            The argparse namespace, with ``opt`` replaced by a dict of parsed
            overrides and ``language`` replaced by the resolved language name.
        """
        args = super(ArgsParser, self).parse_args(argv)
        args.opt = self._parse_opt(args.opt)
        args.language = self._set_language(args.language)
        return args

    def _parse_opt(self, opts):
        """Turn ``key=value`` strings into a dict with YAML-parsed values.

        Args:
            opts: List of ``key=value`` strings, or ``None``/empty.

        Returns:
            dict mapping each key to its YAML-decoded value.

        Raises:
            ValueError: If an entry contains no ``=``.
        """
        config = {}
        if not opts:
            return config
        for item in opts:
            item = item.strip()
            # Split on the first "=" only so values may themselves contain "=".
            key, separator, raw_value = item.partition("=")
            if not separator:
                raise ValueError(
                    "option {!r} must be given as key=value".format(item)
                )
            # NOTE(security): yaml.Loader can construct arbitrary Python
            # objects; values here come straight from the command line.
            config[key] = yaml.load(raw_value, Loader=yaml.Loader)
        return config

    def _set_language(self, type):
        """Resolve the language code and point ``global_config`` at its data.

        Args:
            type: List of language codes from ``-l/--language``; only the
                first entry is used.

        Returns:
            The resolved language name (a script-group name such as
            ``"latin"`` when the code belongs to a group, otherwise the code
            itself).

        Raises:
            AssertionError: If no language was given, the code is unknown, or
                the matching dictionary file does not exist under
                ``project_path``.
        """
        # Validate before indexing: ``type`` is None when -l is omitted.
        assert type, "please use -l or --language to choose language type"
        lang = type[0]
        assert lang in support_list or lang in multi_lang, (
            "the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, "
            "please check your running command".format(multi_lang, type)
        )

        # Collapse an individual language code to its script group, if any.
        script_groups = (
            ("latin", latin_lang),
            ("arabic", arabic_lang),
            ("cyrillic", cyrillic_lang),
            ("devanagari", devanagari_lang),
        )
        for script_name, members in script_groups:
            if lang in members:
                lang = script_name
                break

        global_config["Global"][
            "character_dict_path"
        ] = "ppocr/utils/dict/{}_dict.txt".format(lang)
        global_config["Global"]["save_model_dir"] = "./output/rec_{}_lite".format(lang)
        global_config["Train"]["dataset"]["label_file_list"] = [
            "train_data/{}_train.txt".format(lang)
        ]
        global_config["Eval"]["dataset"]["label_file_list"] = [
            "train_data/{}_val.txt".format(lang)
        ]
        global_config["Global"]["character_type"] = lang

        dict_file = os.path.join(
            project_path, global_config["Global"]["character_dict_path"]
        )
        assert os.path.isfile(dict_file), (
            "Loss default dictionary file {}_dict.txt.You can download it from "
            "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/"
        ).format(lang)
        return lang


def merge_config(config):
    """
    Merge config into global config.

    Keys without a dot replace (or, for dict values on existing keys, update)
    the top-level entry. Dotted keys such as ``"A.b.c"`` walk into the nested
    mappings and assign the final component.

    Args:
        config (dict): Config to be merged.

    Returns: global config

    Raises:
        AssertionError: If the first component of a dotted key is not a
            top-level key of the global config.
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
            continue

        sub_keys = key.split(".")
        assert (
            sub_keys[0] in global_config
        ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
            global_config.keys(), sub_keys[0]
        )
        # Descend through the intermediate components, then set the leaf.
        cur = global_config[sub_keys[0]]
        for sub_key in sub_keys[1:-1]:
            cur = cur[sub_key]
        cur[sub_keys[-1]] = value
    return global_config


def loss_file(path):
    """Assert that ``path`` exists.

    Raises:
        AssertionError: If nothing exists at ``path``.
    """
    assert os.path.exists(
        path
    ), "There is no such file:{},Please do not forget to put in the specified file".format(
        path
    )


if __name__ == "__main__":
    FLAGS = ArgsParser().parse_args()
    merge_config(FLAGS.opt)

    save_file_path = "rec_{}_lite_train.yml".format(FLAGS.language)
    # Drop any config generated by a previous run before writing a new one.
    if os.path.isfile(save_file_path):
        os.remove(save_file_path)

    # Each optional path override is stored as given, then checked for
    # existence relative to project_path.
    if FLAGS.train:
        global_config["Train"]["dataset"]["label_file_list"] = [FLAGS.train]
        train_label_path = os.path.join(project_path, FLAGS.train)
        loss_file(train_label_path)
    if FLAGS.val:
        global_config["Eval"]["dataset"]["label_file_list"] = [FLAGS.val]
        eval_label_path = os.path.join(project_path, FLAGS.val)
        loss_file(eval_label_path)
    if FLAGS.dict:
        global_config["Global"]["character_dict_path"] = FLAGS.dict
        dict_path = os.path.join(project_path, FLAGS.dict)
        loss_file(dict_path)
    if FLAGS.data_dir:
        global_config["Eval"]["dataset"]["data_dir"] = FLAGS.data_dir
        global_config["Train"]["dataset"]["data_dir"] = FLAGS.data_dir
        data_dir = os.path.join(project_path, FLAGS.data_dir)
        loss_file(data_dir)

    with open(save_file_path, "w") as f:
        yaml.dump(dict(global_config), f, default_flow_style=False, sort_keys=False)

    logging.info("Project path is :{}".format(project_path))
    logging.info(
        "Train list path set to :{}".format(
            global_config["Train"]["dataset"]["label_file_list"][0]
        )
    )
    logging.info(
        "Eval list path set to :{}".format(
            global_config["Eval"]["dataset"]["label_file_list"][0]
        )
    )
    logging.info(
        "Dataset root path set to :{}".format(
            global_config["Eval"]["dataset"]["data_dir"]
        )
    )
    logging.info(
        "Dict path set to :{}".format(
            global_config["Global"]["character_dict_path"]
        )
    )
    logging.info(
        "Config file set to :configs/rec/multi_language/{}".format(save_file_path)
    )