diff --git a/deploy/python/postprocess.py b/deploy/python/postprocess.py
index e6709dcef..3a4e3726b 100644
--- a/deploy/python/postprocess.py
+++ b/deploy/python/postprocess.py
@@ -164,11 +164,16 @@ class ScoreOutput(object):
 
 
 class Topk(object):
-    def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
+    def __init__(self,
+                 topk=1,
+                 class_id_map_file=None,
+                 delimiter=None,
+                 label_list=None):
         assert isinstance(topk, (int, ))
         self.topk = topk
         delimiter = delimiter if delimiter is not None else " "
-        self.class_id_map = parse_class_id_map(class_id_map_file, delimiter)
+        self.class_id_map = parse_class_id_map(
+            class_id_map_file, delimiter) if not label_list else label_list
 
     def __call__(self, x, file_names=None):
         if file_names is not None:
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index bebe43af8..40b2c58df 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -27,7 +27,7 @@ import random
 from ppcls.utils.misc import AverageMeter
 from ppcls.utils import logger
 from ppcls.utils.logger import init_logger
-from ppcls.utils.config import print_config
+from ppcls.utils.config import print_config, dump_infer_config
 from ppcls.data import build_dataloader
 from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
 from ppcls.arch import apply_to_static
@@ -523,10 +523,9 @@ class Engine(object):
         else:
             paddle.jit.save(model, save_path)
         if self.config["Global"].get("export_for_fd", False):
-            src_path = self.config["Global"]["infer_config_path"]
             dst_path = os.path.join(
                 self.config["Global"]["save_inference_dir"], 'inference.yml')
-            shutil.copy(src_path, dst_path)
+            dump_infer_config(self.config, dst_path)
         logger.info(
             f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
         )
diff --git a/ppcls/utils/config.py b/ppcls/utils/config.py
index 1a0fc1933..30bd160a1 100644
--- a/ppcls/utils/config.py
+++ b/ppcls/utils/config.py
@@ -18,6 +18,7 @@ import argparse
 import yaml
 from . import logger
 from . import check
+from collections import OrderedDict
 
 __all__ = ['get_config']
 
@@ -213,3 +214,44 @@ def parse_args():
     )
     args = parser.parse_args()
     return args
+
+def represent_dictionary_order(self, dict_data):
+    return self.represent_mapping('tag:yaml.org,2002:map', dict_data.items())
+
+
+def setup_orderdict():
+    yaml.add_representer(OrderedDict, represent_dictionary_order)
+
+
+def dump_infer_config(config, path):
+    setup_orderdict()
+    infer_cfg = OrderedDict()
+    transforms = config["Infer"]["transforms"]
+    for transform in transforms:
+        if "NormalizeImage" in transform:
+            transform["NormalizeImage"]["channel_num"] = 3
+    infer_cfg["PreProcess"] = {
+        "transform_ops": [
+            infer_preprocess for infer_preprocess in transforms
+            if "DecodeImage" not in infer_preprocess
+        ]
+    }
+
+    postprocess_dict = config["Infer"]["PostProcess"]
+    with open(postprocess_dict["class_id_map_file"], 'r') as f:
+        label_id_maps = f.readlines()
+    label_names = []
+    for line in label_id_maps:
+        line = line.strip().split(' ', 1)
+        label_names.append(line[1:][0])
+
+    infer_cfg["PostProcess"] = {
+        "Topk": OrderedDict({
+            "topk": postprocess_dict["topk"],
+            "label_list": label_names
+        })
+    }
+    with open(path, 'w') as f:
+        yaml.dump(infer_cfg, f)
+    logger.info("Export inference config file to {}".format(
+        os.path.join(path)))
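
Note (not part of the patch): the sketch below illustrates how the inference.yml written by dump_infer_config() could be consumed on the deployment side through the new label_list argument of Topk, so that class_id_map_file is no longer needed at inference time. The file path, the yaml.safe_load call, and the import location are assumptions for illustration only, not code from this diff.

# Minimal consumption sketch (illustrative; assumes it is run from
# deploy/python/, where the patched Topk lives, and that inference.yml
# was produced by the export step with export_for_fd enabled).
import yaml

from postprocess import Topk

with open("inference.yml", "r") as f:
    infer_cfg = yaml.safe_load(f)

# dump_infer_config() stores the top-k size and the label names under
# PostProcess.Topk, mirroring the keys written in the patch above.
topk_cfg = infer_cfg["PostProcess"]["Topk"]
topk_op = Topk(topk=topk_cfg["topk"], label_list=topk_cfg["label_list"])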