fix_scipts
parent
478a019da3
commit
35ee32e32f
|
@ -7,7 +7,7 @@ function _set_params(){
|
|||
batch_size=${2:-"64"}
|
||||
fp_item=${3:-"fp32"} # fp32|fp16
|
||||
epochs=${4:-"2"} # 可选,如果需要修改代码提前中断
|
||||
model_name=${5:-"model_name"}
|
||||
model_item=${5:-"model_item"}
|
||||
run_log_path=${TRAIN_LOG_DIR:-$(pwd)} # TRAIN_LOG_DIR 后续QA设置该参数
|
||||
|
||||
index=1
|
||||
|
@ -23,16 +23,17 @@ function _set_params(){
|
|||
device=${CUDA_VISIBLE_DEVICES//,/ }
|
||||
arr=(${device})
|
||||
num_gpu_devices=${#arr[*]}
|
||||
log_file=${run_log_path}/clas_${model_name}_${run_mode}_bs${batch_size}_${fp_item}_${num_gpu_devices}
|
||||
log_file=${run_log_path}/clas_${model_item}_${run_mode}_bs${batch_size}_${fp_item}_${num_gpu_devices}
|
||||
model_name=${model_item}_bs${batch_size}_${fp_item} # model_item 用于yml匹配,model_name用于入库
|
||||
}
|
||||
function _train(){
|
||||
echo "Train on ${num_gpu_devices} GPUs"
|
||||
echo "current CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES, gpus=$num_gpu_devices, batch_size=$batch_size"
|
||||
|
||||
if [ ${fp_item} = "fp32" ];then
|
||||
model_config=`find ppcls/configs/ImageNet -name ${model_name}.yaml`
|
||||
model_config=`find ppcls/configs/ImageNet -name ${model_item}.yaml`
|
||||
else
|
||||
model_config=`find ppcls/configs/ImageNet -name ${model_name}_fp16.yaml`
|
||||
model_config=`find ppcls/configs/ImageNet -name ${model_item}_fp16.yaml`
|
||||
fi
|
||||
|
||||
train_cmd="-c ${model_config} -o DataLoader.Train.sampler.batch_size=${batch_size} -o Global.epochs=${epochs} -o Global.eval_during_train=False -o Global.print_batch_step=2"
|
||||
|
|
Loading…
Reference in New Issue