add rec model for tipc
parent
475a60a37c
commit
a3465a6591
|
@ -35,6 +35,24 @@ class RecPredictor(Predictor):
|
|||
self.preprocess_ops = create_operators(config["RecPreProcess"][
|
||||
"transform_ops"])
|
||||
self.postprocess = build_postprocess(config["RecPostProcess"])
|
||||
self.benchmark = config["Global"].get("benchmark", False)
|
||||
|
||||
import auto_log
|
||||
pid = os.getpid()
|
||||
self.auto_logger = auto_log.AutoLogger(
|
||||
model_name=config["Global"].get("model_name", "rec"),
|
||||
model_precision='fp16' if config["Global"]["use_fp16"] else 'fp32',
|
||||
batch_size=config["Global"].get("batch_size", 1),
|
||||
data_shape=[3, 224, 224],
|
||||
save_path=config["Global"].get("save_log_path", "./auto_log.log"),
|
||||
inference_config=self.config,
|
||||
pids=pid,
|
||||
process_name=None,
|
||||
gpu_ids=None,
|
||||
time_keys=[
|
||||
'preprocess_time', 'inference_time', 'postprocess_time'
|
||||
],
|
||||
warmup=2)
|
||||
|
||||
def predict(self, images, feature_normalize=True):
|
||||
input_names = self.paddle_predictor.get_input_names()
|
||||
|
@ -44,16 +62,22 @@ class RecPredictor(Predictor):
|
|||
output_tensor = self.paddle_predictor.get_output_handle(output_names[
|
||||
0])
|
||||
|
||||
if self.benchmark:
|
||||
self.auto_logger.times.start()
|
||||
if not isinstance(images, (list, )):
|
||||
images = [images]
|
||||
for idx in range(len(images)):
|
||||
for ops in self.preprocess_ops:
|
||||
images[idx] = ops(images[idx])
|
||||
image = np.array(images)
|
||||
if self.benchmark:
|
||||
self.auto_logger.times.stamp()
|
||||
|
||||
input_tensor.copy_from_cpu(image)
|
||||
self.paddle_predictor.run()
|
||||
batch_output = output_tensor.copy_to_cpu()
|
||||
if self.benchmark:
|
||||
self.auto_logger.times.stamp()
|
||||
|
||||
if feature_normalize:
|
||||
feas_norm = np.sqrt(
|
||||
|
@ -62,6 +86,9 @@ class RecPredictor(Predictor):
|
|||
|
||||
if self.postprocess is not None:
|
||||
batch_output = self.postprocess(batch_output)
|
||||
|
||||
if self.benchmark:
|
||||
self.auto_logger.times.end(stamp=True)
|
||||
return batch_output
|
||||
|
||||
|
||||
|
@ -85,16 +112,19 @@ def main(config):
|
|||
batch_names.append(img_name)
|
||||
cnt += 1
|
||||
|
||||
if cnt % config["Global"]["batch_size"] == 0 or (idx + 1) == len(image_list):
|
||||
if len(batch_imgs) == 0:
|
||||
if cnt % config["Global"]["batch_size"] == 0 or (idx + 1
|
||||
) == len(image_list):
|
||||
if len(batch_imgs) == 0:
|
||||
continue
|
||||
|
||||
|
||||
batch_results = rec_predictor.predict(batch_imgs)
|
||||
for number, result_dict in enumerate(batch_results):
|
||||
filename = batch_names[number]
|
||||
print("{}:\t{}".format(filename, result_dict))
|
||||
batch_imgs = []
|
||||
batch_names = []
|
||||
if rec_predictor.benchmark:
|
||||
rec_predictor.auto_logger.report()
|
||||
|
||||
return
|
||||
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
===========================train_params===========================
|
||||
model_name:GeneralRecognition_PPLCNet_x2_5
|
||||
python:python3.7
|
||||
gpu_list:0|0,1
|
||||
-o Global.device:gpu
|
||||
-o Global.auto_cast:null
|
||||
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
|
||||
-o Global.output_dir:./output/
|
||||
-o DataLoader.Train.sampler.batch_size:8
|
||||
-o Global.pretrained_model:null
|
||||
train_model_name:latest
|
||||
train_infer_img_dir:./dataset/ILSVRC2012/val
|
||||
null:null
|
||||
##
|
||||
trainer:norm_train
|
||||
norm_train:tools/train.py -c ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
|
||||
pact_train:null
|
||||
fpgm_train:null
|
||||
distill_train:null
|
||||
null:null
|
||||
null:null
|
||||
##
|
||||
===========================eval_params===========================
|
||||
eval:tools/eval.py -c ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml
|
||||
null:null
|
||||
##
|
||||
===========================infer_params==========================
|
||||
-o Global.save_inference_dir:./inference
|
||||
-o Global.pretrained_model:
|
||||
norm_export:tools/export_model.py -c ppcls/configs/GeneralRecognition/GeneralRecognition_PPLCNet_x2_5.yaml
|
||||
quant_export:null
|
||||
fpgm_export:null
|
||||
distill_export:null
|
||||
kl_quant:null
|
||||
export2:null
|
||||
pretrained_model_url:https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/pretrain/general_PPLCNet_x2_5_pretrained_v1.0.pdparams
|
||||
infer_model:../inference/
|
||||
infer_export:True
|
||||
infer_quant:Fasle
|
||||
inference:python/predict_rec.py -c configs/inference_rec.yaml
|
||||
-o Global.use_gpu:True|False
|
||||
-o Global.enable_mkldnn:True|False
|
||||
-o Global.cpu_num_threads:1|6
|
||||
-o Global.batch_size:1|16
|
||||
-o Global.use_tensorrt:True|False
|
||||
-o Global.use_fp16:True|False
|
||||
-o Global.rec_inference_model_dir:../inference
|
||||
-o Global.infer_imgs:../dataset/Aliproduct/demo_test/
|
||||
-o Global.save_log_path:null
|
||||
-o Global.benchmark:True
|
||||
null:null
|
||||
null:null
|
|
@ -37,6 +37,22 @@ model_name=$(func_parser_value "${lines[1]}")
|
|||
model_url_value=$(func_parser_value "${lines[35]}")
|
||||
model_url_key=$(func_parser_key "${lines[35]}")
|
||||
|
||||
if [[ $FILENAME == *GeneralRecognition* ]];then
|
||||
cd dataset
|
||||
rm -rf Aliproduct
|
||||
rm -rf train_reg_all_data.txt
|
||||
rm -rf demo_train
|
||||
wget -nc https://paddle-imagenet-models-name.bj.bcebos.com/data/whole_chain/tipc_shitu_demo_data.tar
|
||||
tar -xf tipc_shitu_demo_data.tar
|
||||
ln -s tipc_shitu_demo_data Aliproduct
|
||||
ln -s tipc_shitu_demo_data/demo_train.txt train_reg_all_data.txt
|
||||
ln -s tipc_shitu_demo_data/demo_train demo_train
|
||||
cd tipc_shitu_demo_data
|
||||
ln -s demo_test.txt val_list.txt
|
||||
cd ../../
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ ${MODE} = "lite_train_lite_infer" ] || [ ${MODE} = "lite_train_whole_infer" ];then
|
||||
# pretrain lite train data
|
||||
cd dataset
|
||||
|
|
|
@ -291,8 +291,12 @@ else
|
|||
export FLAGS_cudnn_deterministic=True
|
||||
eval $cmd
|
||||
status_check $? "${cmd}" "${status_log}"
|
||||
|
||||
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${$model_name}/${train_model_name}")
|
||||
|
||||
if [[ $FILENAME == *GeneralRecognition* ]]; then
|
||||
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/RecModel/${train_model_name}")
|
||||
else
|
||||
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/${model_name}/${train_model_name}")
|
||||
fi
|
||||
# save norm trained models to set pretrain for pact training and fpgm training
|
||||
if [ ${trainer} = ${trainer_norm} ]; then
|
||||
load_norm_train_model=${set_eval_pretrain}
|
||||
|
@ -308,7 +312,11 @@ else
|
|||
if [ ${run_export} != "null" ]; then
|
||||
# run export model
|
||||
save_infer_path="${save_log}"
|
||||
set_export_weight=$(func_set_params "${export_weight}" "${save_log}/${model_name}/${train_model_name}")
|
||||
if [[ $FILENAME == *GeneralRecognition* ]]; then
|
||||
set_eval_pretrain=$(func_set_params "${pretrain_model_key}" "${save_log}/RecModel/${train_model_name}")
|
||||
else
|
||||
set_export_weight=$(func_set_params "${export_weight}" "${save_log}/${model_name}/${train_model_name}")
|
||||
fi
|
||||
set_save_infer_key=$(func_set_params "${save_infer_key}" "${save_infer_path}")
|
||||
export_cmd="${python} ${run_export} ${set_export_weight} ${set_save_infer_key}"
|
||||
eval $export_cmd
|
||||
|
|
Loading…
Reference in New Issue