export with label (#3166)

pull/3199/head
zhangyubo0722 2024-07-25 21:31:21 +08:00 committed by GitHub
parent 5dde31371c
commit fe5700fce6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 51 additions and 5 deletions

View File

@ -164,11 +164,16 @@ class ScoreOutput(object):
class Topk(object):
def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
def __init__(self,
topk=1,
class_id_map_file=None,
delimiter=None,
label_list=None):
assert isinstance(topk, (int, ))
self.topk = topk
delimiter = delimiter if delimiter is not None else " "
self.class_id_map = parse_class_id_map(class_id_map_file, delimiter)
self.class_id_map = parse_class_id_map(
class_id_map_file, delimiter) if not label_list else label_list
def __call__(self, x, file_names=None):
if file_names is not None:

View File

@ -27,7 +27,7 @@ import random
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.utils.config import print_config, dump_infer_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
from ppcls.arch import apply_to_static
@ -523,10 +523,9 @@ class Engine(object):
else:
paddle.jit.save(model, save_path)
if self.config["Global"].get("export_for_fd", False):
src_path = self.config["Global"]["infer_config_path"]
dst_path = os.path.join(
self.config["Global"]["save_inference_dir"], 'inference.yml')
shutil.copy(src_path, dst_path)
dump_infer_config(self.config, dst_path)
logger.info(
f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
)

View File

@ -18,6 +18,7 @@ import argparse
import yaml
from . import logger
from . import check
from collections import OrderedDict
__all__ = ['get_config']
@ -213,3 +214,44 @@ def parse_args():
)
args = parser.parse_args()
return args
def represent_dictionary_order(self, dict_data):
return self.represent_mapping('tag:yaml.org,2002:map', dict_data.items())
def setup_orderdict():
yaml.add_representer(OrderedDict, represent_dictionary_order)
def dump_infer_config(config, path):
setup_orderdict()
infer_cfg = OrderedDict()
transforms = config["Infer"]["transforms"]
for transform in transforms:
if "NormalizeImage" in transform:
transform["NormalizeImage"]["channel_num"] = 3
infer_cfg["PreProcess"] = {
"transform_ops": [
infer_preprocess for infer_preprocess in transforms
if "DecodeImage" not in infer_preprocess
]
}
postprocess_dict = config["Infer"]["PostProcess"]
with open(postprocess_dict["class_id_map_file"], 'r') as f:
label_id_maps = f.readlines()
label_names = []
for line in label_id_maps:
line = line.strip().split(' ', 1)
label_names.append(line[1:][0])
infer_cfg["PostProcess"] = {
"Topk": OrderedDict({
"topk": postprocess_dict["topk"],
"label_list": label_names
})
}
with open(path, 'w') as f:
yaml.dump(infer_cfg, f)
logger.info("Export inference config file to {}".format(
os.path.join(path)))