update config for ShiTu_rec (#3245)
* update config for ShiTu_rec * update for ShiTu_Recpull/3254/head
parent
de0f57521d
commit
716fbcc574
|
@ -17,10 +17,14 @@ Global:
|
|||
use_dali: False
|
||||
to_static: False
|
||||
|
||||
# mixed precision
|
||||
AMP:
|
||||
scale_loss: 128
|
||||
use_amp: False
|
||||
use_fp16_test: False
|
||||
scale_loss: 128.0
|
||||
use_dynamic_loss_scaling: True
|
||||
# O1: mixed fp16
|
||||
use_promote: False
|
||||
# O1: mixed fp16, O2: pure fp16
|
||||
level: O1
|
||||
|
||||
# model architecture
|
||||
|
|
|
@ -17,10 +17,14 @@ Global:
|
|||
use_dali: False
|
||||
to_static: False
|
||||
|
||||
# mixed precision
|
||||
AMP:
|
||||
scale_loss: 128
|
||||
use_amp: False
|
||||
use_fp16_test: False
|
||||
scale_loss: 128.0
|
||||
use_dynamic_loss_scaling: True
|
||||
# O1: mixed fp16
|
||||
use_promote: False
|
||||
# O1: mixed fp16, O2: pure fp16
|
||||
level: O1
|
||||
|
||||
# model architecture
|
||||
|
|
|
@ -18,10 +18,14 @@ Global:
|
|||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: ./inference
|
||||
|
||||
# mixed precision
|
||||
AMP:
|
||||
scale_loss: 65536
|
||||
use_amp: False
|
||||
use_fp16_test: False
|
||||
scale_loss: 128.0
|
||||
use_dynamic_loss_scaling: True
|
||||
# O1: mixed fp16
|
||||
use_promote: False
|
||||
# O1: mixed fp16, O2: pure fp16
|
||||
level: O1
|
||||
|
||||
# model architecture
|
||||
|
|
|
@ -193,5 +193,5 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
|
|||
|
||||
# for PaddleX
|
||||
ClsDataset = ImageNetDataset
|
||||
ShiTuDataset = ImageNetDataset
|
||||
ShiTuRecDataset = ImageNetDataset
|
||||
MLClsDataset = MultiLabelDataset
|
||||
|
|
|
@ -226,7 +226,13 @@ def setup_orderdict():
|
|||
def dump_infer_config(config, path):
|
||||
setup_orderdict()
|
||||
infer_cfg = OrderedDict()
|
||||
transforms = config["Infer"]["transforms"]
|
||||
if config.get("Infer"):
|
||||
transforms = config["Infer"]["transforms"]
|
||||
elif config["DataLoader"]["Eval"].get("Query"):
|
||||
transforms = config["DataLoader"]["Eval"]["Query"]["dataset"]["transform_ops"]
|
||||
transforms.append({"ToCHWImage": None})
|
||||
else:
|
||||
logger.error("This config does not support dump transform config!")
|
||||
for transform in transforms:
|
||||
if "NormalizeImage" in transform:
|
||||
transform["NormalizeImage"]["channel_num"] = 3
|
||||
|
@ -241,28 +247,30 @@ def dump_infer_config(config, path):
|
|||
if "DecodeImage" not in infer_preprocess
|
||||
]
|
||||
}
|
||||
if config.get("Infer"):
|
||||
postprocess_dict = config["Infer"]["PostProcess"]
|
||||
|
||||
postprocess_dict = config["Infer"]["PostProcess"]
|
||||
with open(postprocess_dict["class_id_map_file"], 'r') as f:
|
||||
label_id_maps = f.readlines()
|
||||
label_names = []
|
||||
for line in label_id_maps:
|
||||
line = line.strip().split(' ', 1)
|
||||
label_names.append(line[1:][0])
|
||||
with open(postprocess_dict["class_id_map_file"], 'r') as f:
|
||||
label_id_maps = f.readlines()
|
||||
label_names = []
|
||||
for line in label_id_maps:
|
||||
line = line.strip().split(' ', 1)
|
||||
label_names.append(line[1:][0])
|
||||
|
||||
postprocess_name = postprocess_dict.get("name", None)
|
||||
postprocess_dict.pop("class_id_map_file")
|
||||
postprocess_dict.pop("name")
|
||||
dic = OrderedDict()
|
||||
for item in postprocess_dict.items():
|
||||
dic[item[0]] = item[1]
|
||||
dic['label_list'] = label_names
|
||||
postprocess_name = postprocess_dict.get("name", None)
|
||||
postprocess_dict.pop("class_id_map_file")
|
||||
postprocess_dict.pop("name")
|
||||
dic = OrderedDict()
|
||||
for item in postprocess_dict.items():
|
||||
dic[item[0]] = item[1]
|
||||
dic['label_list'] = label_names
|
||||
|
||||
if postprocess_name:
|
||||
infer_cfg["PostProcess"] = {postprocess_name: dic}
|
||||
if postprocess_name:
|
||||
infer_cfg["PostProcess"] = {postprocess_name: dic}
|
||||
else:
|
||||
raise ValueError("PostProcess name is not specified")
|
||||
else:
|
||||
raise ValueError("PostProcess name is not specified")
|
||||
|
||||
infer_cfg["PostProcess"] = {"NormalizeFeatures": None}
|
||||
with open(path, 'w') as f:
|
||||
yaml.dump(infer_cfg, f)
|
||||
logger.info("Export inference config file to {}".format(os.path.join(path)))
|
||||
|
|
Loading…
Reference in New Issue