机器学习模型超参数网格搜索脚本
本文以阿里云机器学习平台（PAI）上的 ps_smart（GBDT）算法为例，提供一个超参数网格搜索的 bash 脚本，支持控制并发数。示例是一个 LTV 预测的回归任务：脚本计算 MAE、RMSE、WAPE 三个评估指标，并以模型参数组合命名的分区写入一张 MaxCompute 表，便于按参数对比结果。
备注：需要预先安装并配置好 ODPS 命令行客户端 odpscmd（脚本通过 odps_config.ini 读取连接配置），并预先创建好用于存放评估结果的分区表（脚本中为 ps_smart_0804_ltv）。
#!/bin/bash
#set -e
#
# Hyper-parameter grid search for the PAI ps_smart algorithm.
#
# Configuration (environment):
#   ODPS_CMD          - optional odpscmd command line. It is word-split when
#                       executed, so it may carry flags such as --config=...
#                       Defaults to the path below.
#   input_train_table - training table used by run_job; this script does not
#                       set it, so export it before running.
#   input_test_table  - test table used by run_job; this script does not set
#                       it, so export it before running.
odps="${ODPS_CMD:-/home/weisu.yxd/.odps/bin/odpscmd --config=odps_config.ini}"
#work_dir=$(readlink -f $(dirname $0))
#echo "work_dir=$work_dir"

## hyper param search config
# grid_search takes the Cartesian product of these lists (skipping the
# combinations where l1 and l2 are both zero); each combination becomes one
# line in hyper_params_file.
tree_count=(350 300 250)
max_depth=(5 6)
feature_ratio=(1 0.9 0.8)
l1_reg=(0 1 2)
l2_reg=(0 1)
shrinkage=(0.05 0.1)
hyper_params_file='hyper_params.txt'
#######################################
# Print an informational log line to stdout:
#   "<YYYY-mm-dd HH:MM:SS> [INFO] (<pid>)(<user>): <message>"
# Nothing is printed when LOG_LEVEL is WARN or ERROR. Any other value,
# including an unset LOG_LEVEL, lets the message through.
# Arguments:
#   $* - message words, joined with spaces
#######################################
log_info() {
  case "$LOG_LEVEL" in
    WARN|ERROR)
      # Quieter levels requested: suppress INFO output.
      ;;
    *)
      echo "$(date +"%Y-%m-%d %H:%M:%S") [INFO] ($$)($USER): $*"
      ;;
  esac
}
# Succeed if $1 is a numeric literal equal to zero (0, 00, 0.0, .0, -0, ...).
# Used by grid_search; unlike `expr "$v" = 0`, it also treats "0.0" as zero.
_grid_is_zero() {
  local zero_re='^[+-]?(0+(\.0*)?|\.0+)$'
  [[ "$1" =~ $zero_re ]]
}

#######################################
# Enumerate the hyper-parameter grid and write it to hyper_params_file, one
# combination per line, in the order run_job expects its arguments:
#   <tree_count> <max_depth> <feature_ratio> <l1> <l2> <shrinkage>
# Combinations where l1 and l2 are both zero are skipped.
# Globals (read):
#   tree_count, max_depth, feature_ratio, l1_reg, l2_reg, shrinkage,
#   hyper_params_file
# Outputs:
#   Overwrites hyper_params_file. The file is created even if the grid is
#   empty, so the caller can always read it.
# Returns:
#   0 on success, 1 if the file could not be written.
#######################################
function grid_search()
{
  log_info "function [$FUNCNAME] begin"
  local tree depth fea_ratio l1 l2 lr
  # One redirect around the whole loop truncates any previous file and opens
  # it only once (instead of rm + one append per combination).
  {
    for tree in "${tree_count[@]}"; do
      for depth in "${max_depth[@]}"; do
        for fea_ratio in "${feature_ratio[@]}"; do
          for l1 in "${l1_reg[@]}"; do
            for l2 in "${l2_reg[@]}"; do
              # Skip the combination with no regularization at all.
              if _grid_is_zero "$l1" && _grid_is_zero "$l2"; then
                continue
              fi
              for lr in "${shrinkage[@]}"; do
                printf '%s %s %s %s %s %s\n' \
                  "$tree" "$depth" "$fea_ratio" "$l1" "$l2" "$lr"
              done
            done
          done
        done
      done
    done
  } > "${hyper_params_file}" || {
    printf 'grid_search: cannot write %s\n' "${hyper_params_file}" >&2
    return 1
  }
  log_info "function [$FUNCNAME] end"
}
#######################################
# Run one ODPS statement for run_job through $odps.
# On failure, print an error naming the step to stderr.
# Globals (read):
#   odps - odpscmd command prefix; deliberately word-split because it holds
#          "<binary> --config=..."
# Arguments:
#   $1 - human-readable step name (used only in the error message)
#   $2 - statement passed to odpscmd -e
# Returns:
#   odpscmd's exit status.
#######################################
_run_job_odps() {
  local step=$1
  local sql=$2
  local status=0
  # shellcheck disable=SC2086  # $odps must be word-split (binary + flags)
  $odps -e "$sql" || status=$?
  if (( status != 0 )); then
    printf '%s [ERROR] (%s): %s failed with exit code %d\n' \
      "$(date +"%Y-%m-%d %H:%M:%S")" "$$" "$step" "$status" >&2
  fi
  return "$status"
}

#######################################
# Train one ps_smart model for a single hyper-parameter combination, score
# input_test_table with it, and write MAE / RMSE / WAPE into the metric
# table partition named after the combination.
# Globals (read):
#   odps              - odpscmd command prefix
#   input_train_table - training table (must be non-empty)
#   input_test_table  - test table (must be non-empty)
#   metric_table      - optional result table (default ps_smart_0804_ltv);
#                       it must already exist and be partitioned by pt
# Arguments:
#   $1 tree count, $2 max depth, $3 feature ratio, $4 l1, $5 l2, $6 shrinkage
# Returns:
#   0 on success; 1 if an input table variable is empty; otherwise the exit
#   status of the first failing training / prediction / metric step.
#######################################
function run_job() {
  log_info "function [$FUNCNAME] begin"
  local tree_count=$1
  local max_depth=$2
  local fea_ratio=$3
  local l1=$4
  local l2=$5
  local lr=$6
  local result_table=${metric_table:-ps_smart_0804_ltv}

  if [[ -z "${input_train_table}" || -z "${input_test_table}" ]]; then
    printf '%s [ERROR] (%s): run_job: input_train_table and input_test_table must be set\n' \
      "$(date +"%Y-%m-%d %H:%M:%S")" "$$" >&2
    return 1
  fi

  # Build a table-name-safe tag from the parameters. The first "0." becomes
  # "p", which keeps the original naming (0.05 -> p05). Any dot still left
  # also becomes "p" (1.5 -> 1p5), because dots are not valid in table names.
  local tag_ratio=${fea_ratio/0./p}
  local tag_l1=${l1/0./p}
  local tag_l2=${l2/0./p}
  local tag_lr=${lr/0./p}
  local model="${tree_count}_${max_depth}_${tag_ratio//./p}_${tag_l1//./p}_${tag_l2//./p}_${tag_lr//./p}"
  log_info "run model: $model"

  # Clean up outputs from a previous run. Failures here are reported but not
  # fatal: a real problem makes the training step below fail and abort.
  _run_job_odps "drop ps_${model}" "drop table if exists ps_${model};"
  _run_job_odps "drop model_output_${model}" "drop table if exists model_output_${model};"
  _run_job_odps "drop imp_${model}" "drop table if exists imp_${model};"

  # `|| return` propagates the failing step's real exit status.
  _run_job_odps "ps_smart training (${model})" "PAI -name ps_smart
-project algo_public
-DinputTableName='${input_train_table}'
-DmodelName='ps_${model}'
-DoutputTableName='model_output_${model}'
-DoutputImportanceTableName='imp_${model}'
-DlabelColName='targetprice'
-DfeatureColNames='kv'
-DenableSparse='true'
-Dobjective='reg:tweedie'
-Dmetric='tweedie-nloglik'
-DfeatureImportanceType='gain'
-DtreeCount='${tree_count}'
-DmaxDepth='${max_depth}'
-Dshrinkage='${lr}'
-Dl2='${l2}'
-Dl1='${l1}'
-Dlifecycle='31'
-DsketchEps='0.03'
-DsampleRatio='1.0'
-DfeatureRatio='${fea_ratio}'
-DbaseScore='0.0'
-DminSplitLoss='0'
" || return

  _run_job_odps "drop output_${model}" "drop table if exists output_${model};"
  _run_job_odps "prediction (${model})" "PAI -name prediction
-project algo_public
-DinputTableName='${input_test_table}'
-DmodelName='ps_${model}'
-DoutputTableName='output_${model}'
-DfeatureColNames='kv'
-DappendColNames='targetprice'
-DenableSparse='true'
-DitemDelimiter=','
-Dlifecycle='128'
" || return

  # The metric table must be created beforehand as a table partitioned by pt.
  _run_job_odps "metric insert (${model})" "INSERT OVERWRITE TABLE ${result_table} PARTITION(pt='${model}')
SELECT AVG(ABS(targetprice-prediction_result)) MAE,
SQRT(AVG((targetprice-prediction_result)*(targetprice-prediction_result))) RMSE,
SUM(ABS(targetprice-prediction_result))/SUM(ABS(targetprice)) WAPE
FROM output_${model};" || return

  log_info "function [$FUNCNAME] end"
  return 0
}
#######################################
# Run run_job once per line of a hyper-parameter file, with at most max_jobs
# jobs running at the same time.
# Concurrency is limited by a token pool. max_jobs newline "tokens" are
# written into a FIFO opened read/write on fd 9. Each job takes a token
# before it starts and puts it back when it finishes.
# Arguments:
#   $1 - file with one parameter set per line, whitespace-separated, in
#        run_job's argument order; blank lines are skipped
#   $2 - optional maximum number of concurrent jobs (default: 10)
# Returns:
#   0 if every job succeeded; 1 if the arguments are invalid, the FIFO could
#   not be set up, or at least one job failed.
#######################################
function run_from_file()
{
  log_info "function [$FUNCNAME] begin"
  local params_file=$1
  local max_jobs=${2:-10} # concurrency limit
  if [[ -z "$params_file" || ! -r "$params_file" ]]; then
    printf 'run_from_file: cannot read hyper-parameter file: %s\n' "$params_file" >&2
    return 1
  fi
  if [[ ! "$max_jobs" =~ ^[1-9][0-9]*$ ]]; then
    printf 'run_from_file: concurrency must be a positive integer: %s\n' "$max_jobs" >&2
    return 1
  fi

  # Create the FIFO in a private temp dir rather than under a fixed name in
  # the current directory, so concurrent runs cannot collide.
  local fifo_dir
  fifo_dir=$(mktemp -d) || return 1
  if ! mkfifo "$fifo_dir/fifo"; then
    rm -rf -- "$fifo_dir"
    return 1
  fi
  # Opening read/write does not block. Once fd 9 holds the pipe, the path
  # is no longer needed.
  exec 9<> "$fifo_dir/fifo"
  rm -rf -- "$fifo_dir"

  # Pre-fill the pool with one token per allowed concurrent job.
  local i
  for ((i = 0; i < max_jobs; i++)); do
    echo "" >&9
  done

  local line
  local -a params
  local -a pids=()
  while IFS= read -r line || [[ -n "$line" ]]; do
    read -r -a params <<< "$line"
    if (( ${#params[@]} == 0 )); then
      continue
    fi
    read -r -u 9 # take a token; blocks while max_jobs jobs are running
    {
      run_job "${params[@]}"
      job_status=$?
      echo "" >&9 # give the token back
      exit "$job_status"
    } < /dev/null &
    pids+=("$!")
  done < "$params_file"

  log_info "wait all task finish,then exit!!!"
  local pid
  local failed=0
  for pid in "${pids[@]}"; do
    wait "$pid" || failed=$((failed + 1))
  done
  exec 9>&- # close the token pool (read/write descriptor)
  log_info "function [$FUNCNAME] end"
  if (( failed > 0 )); then
    printf 'run_from_file: %d of %d jobs failed\n' "$failed" "${#pids[@]}" >&2
    return 1
  fi
  return 0
}
# Entry point: build the parameter grid, then train and evaluate every
# combination. If grid_search fails, the script exits with its status;
# otherwise the script's exit status is that of run_from_file.
grid_search || exit
run_from_file "${hyper_params_file}"
编辑于 2022-08-10 15:45