export with label (#3166)
parent
5dde31371c
commit
fe5700fce6
|
@ -164,11 +164,16 @@ class ScoreOutput(object):
|
|||
|
||||
|
||||
class Topk(object):
|
||||
def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
|
||||
def __init__(self,
|
||||
topk=1,
|
||||
class_id_map_file=None,
|
||||
delimiter=None,
|
||||
label_list=None):
|
||||
assert isinstance(topk, (int, ))
|
||||
self.topk = topk
|
||||
delimiter = delimiter if delimiter is not None else " "
|
||||
self.class_id_map = parse_class_id_map(class_id_map_file, delimiter)
|
||||
self.class_id_map = parse_class_id_map(
|
||||
class_id_map_file, delimiter) if not label_list else label_list
|
||||
|
||||
def __call__(self, x, file_names=None):
|
||||
if file_names is not None:
|
||||
|
|
|
@ -27,7 +27,7 @@ import random
|
|||
from ppcls.utils.misc import AverageMeter
|
||||
from ppcls.utils import logger
|
||||
from ppcls.utils.logger import init_logger
|
||||
from ppcls.utils.config import print_config
|
||||
from ppcls.utils.config import print_config, dump_infer_config
|
||||
from ppcls.data import build_dataloader
|
||||
from ppcls.arch import build_model, RecModel, DistillationModel, TheseusLayer
|
||||
from ppcls.arch import apply_to_static
|
||||
|
@ -523,10 +523,9 @@ class Engine(object):
|
|||
else:
|
||||
paddle.jit.save(model, save_path)
|
||||
if self.config["Global"].get("export_for_fd", False):
|
||||
src_path = self.config["Global"]["infer_config_path"]
|
||||
dst_path = os.path.join(
|
||||
self.config["Global"]["save_inference_dir"], 'inference.yml')
|
||||
shutil.copy(src_path, dst_path)
|
||||
dump_infer_config(self.config, dst_path)
|
||||
logger.info(
|
||||
f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
|
||||
)
|
||||
|
|
|
@ -18,6 +18,7 @@ import argparse
|
|||
import yaml
|
||||
from . import logger
|
||||
from . import check
|
||||
from collections import OrderedDict
|
||||
|
||||
__all__ = ['get_config']
|
||||
|
||||
|
@ -213,3 +214,44 @@ def parse_args():
|
|||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def represent_dictionary_order(self, dict_data):
|
||||
return self.represent_mapping('tag:yaml.org,2002:map', dict_data.items())
|
||||
|
||||
|
||||
def setup_orderdict():
|
||||
yaml.add_representer(OrderedDict, represent_dictionary_order)
|
||||
|
||||
|
||||
def dump_infer_config(config, path):
|
||||
setup_orderdict()
|
||||
infer_cfg = OrderedDict()
|
||||
transforms = config["Infer"]["transforms"]
|
||||
for transform in transforms:
|
||||
if "NormalizeImage" in transform:
|
||||
transform["NormalizeImage"]["channel_num"] = 3
|
||||
infer_cfg["PreProcess"] = {
|
||||
"transform_ops": [
|
||||
infer_preprocess for infer_preprocess in transforms
|
||||
if "DecodeImage" not in infer_preprocess
|
||||
]
|
||||
}
|
||||
|
||||
postprocess_dict = config["Infer"]["PostProcess"]
|
||||
with open(postprocess_dict["class_id_map_file"], 'r') as f:
|
||||
label_id_maps = f.readlines()
|
||||
label_names = []
|
||||
for line in label_id_maps:
|
||||
line = line.strip().split(' ', 1)
|
||||
label_names.append(line[1:][0])
|
||||
|
||||
infer_cfg["PostProcess"] = {
|
||||
"Topk": OrderedDict({
|
||||
"topk": postprocess_dict["topk"],
|
||||
"label_list": label_names
|
||||
})
|
||||
}
|
||||
with open(path, 'w') as f:
|
||||
yaml.dump(infer_cfg, f)
|
||||
logger.info("Export inference config file to {}".format(
|
||||
os.path.join(path)))
|
||||
|
|
Loading…
Reference in New Issue