fix_scipts

pull/1513/head
hysunflower 2021-12-01 11:42:43 +00:00
parent 478a019da3
commit 35ee32e32f
1 changed files with 5 additions and 4 deletions

View File

@ -7,7 +7,7 @@ function _set_params(){
batch_size=${2:-"64"}
fp_item=${3:-"fp32"} # fp32|fp16
epochs=${4:-"2"} # 可选,如果需要修改代码提前中断
model_name=${5:-"model_name"}
model_item=${5:-"model_item"}
run_log_path=${TRAIN_LOG_DIR:-$(pwd)} # TRAIN_LOG_DIR 后续QA设置该参数
index=1
@ -23,16 +23,17 @@ function _set_params(){
device=${CUDA_VISIBLE_DEVICES//,/ }
arr=(${device})
num_gpu_devices=${#arr[*]}
log_file=${run_log_path}/clas_${model_name}_${run_mode}_bs${batch_size}_${fp_item}_${num_gpu_devices}
log_file=${run_log_path}/clas_${model_item}_${run_mode}_bs${batch_size}_${fp_item}_${num_gpu_devices}
model_name=${model_item}_bs${batch_size}_${fp_item} # model_item 用于yml匹配model_name用于入库
}
function _train(){
echo "Train on ${num_gpu_devices} GPUs"
echo "current CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES, gpus=$num_gpu_devices, batch_size=$batch_size"
if [ ${fp_item} = "fp32" ];then
model_config=`find ppcls/configs/ImageNet -name ${model_name}.yaml`
model_config=`find ppcls/configs/ImageNet -name ${model_item}.yaml`
else
model_config=`find ppcls/configs/ImageNet -name ${model_name}_fp16.yaml`
model_config=`find ppcls/configs/ImageNet -name ${model_item}_fp16.yaml`
fi
train_cmd="-c ${model_config} -o DataLoader.Train.sampler.batch_size=${batch_size} -o Global.epochs=${epochs} -o Global.eval_during_train=False -o Global.print_batch_step=2"