update config for ShiTu_rec (#3245)

* update config for ShiTu_rec

* update for ShiTu_Rec
pull/3254/head
cuicheng01 2024-09-12 18:27:01 +08:00 committed by GitHub
parent de0f57521d
commit 716fbcc574
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 46 additions and 26 deletions

View File

@ -17,10 +17,14 @@ Global:
use_dali: False
to_static: False
# mixed precision
AMP:
scale_loss: 128
use_amp: False
use_fp16_test: False
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
use_promote: False
# O1: mixed fp16, O2: pure fp16
level: O1
# model architecture

View File

@ -17,10 +17,14 @@ Global:
use_dali: False
to_static: False
# mixed precision
AMP:
scale_loss: 128
use_amp: False
use_fp16_test: False
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
use_promote: False
# O1: mixed fp16, O2: pure fp16
level: O1
# model architecture

View File

@ -18,10 +18,14 @@ Global:
image_shape: [3, 224, 224]
save_inference_dir: ./inference
# mixed precision
AMP:
scale_loss: 65536
use_amp: False
use_fp16_test: False
scale_loss: 128.0
use_dynamic_loss_scaling: True
# O1: mixed fp16
use_promote: False
# O1: mixed fp16, O2: pure fp16
level: O1
# model architecture

View File

@ -193,5 +193,5 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
# for PaddleX
ClsDataset = ImageNetDataset
ShiTuDataset = ImageNetDataset
ShiTuRecDataset = ImageNetDataset
MLClsDataset = MultiLabelDataset

View File

@ -226,7 +226,13 @@ def setup_orderdict():
def dump_infer_config(config, path):
setup_orderdict()
infer_cfg = OrderedDict()
transforms = config["Infer"]["transforms"]
if config.get("Infer"):
transforms = config["Infer"]["transforms"]
elif config["DataLoader"]["Eval"].get("Query"):
transforms = config["DataLoader"]["Eval"]["Query"]["dataset"]["transform_ops"]
transforms.append({"ToCHWImage": None})
else:
logger.error("This config does not support dump transform config!")
for transform in transforms:
if "NormalizeImage" in transform:
transform["NormalizeImage"]["channel_num"] = 3
@ -241,28 +247,30 @@ def dump_infer_config(config, path):
if "DecodeImage" not in infer_preprocess
]
}
if config.get("Infer"):
postprocess_dict = config["Infer"]["PostProcess"]
postprocess_dict = config["Infer"]["PostProcess"]
with open(postprocess_dict["class_id_map_file"], 'r') as f:
label_id_maps = f.readlines()
label_names = []
for line in label_id_maps:
line = line.strip().split(' ', 1)
label_names.append(line[1:][0])
with open(postprocess_dict["class_id_map_file"], 'r') as f:
label_id_maps = f.readlines()
label_names = []
for line in label_id_maps:
line = line.strip().split(' ', 1)
label_names.append(line[1:][0])
postprocess_name = postprocess_dict.get("name", None)
postprocess_dict.pop("class_id_map_file")
postprocess_dict.pop("name")
dic = OrderedDict()
for item in postprocess_dict.items():
dic[item[0]] = item[1]
dic['label_list'] = label_names
postprocess_name = postprocess_dict.get("name", None)
postprocess_dict.pop("class_id_map_file")
postprocess_dict.pop("name")
dic = OrderedDict()
for item in postprocess_dict.items():
dic[item[0]] = item[1]
dic['label_list'] = label_names
if postprocess_name:
infer_cfg["PostProcess"] = {postprocess_name: dic}
if postprocess_name:
infer_cfg["PostProcess"] = {postprocess_name: dic}
else:
raise ValueError("PostProcess name is not specified")
else:
raise ValueError("PostProcess name is not specified")
infer_cfg["PostProcess"] = {"NormalizeFeatures": None}
with open(path, 'w') as f:
yaml.dump(infer_cfg, f)
logger.info("Export inference config file to {}".format(os.path.join(path)))