Hyperparameter Grid Search Script for Machine Learning Models

Using the ps_smart (GBDT) algorithm on Alibaba Cloud's machine learning platform as an example, this article provides a bash script for hyperparameter grid search with a configurable concurrency limit. The example is an LTV-prediction regression task: for each parameter combination it computes three evaluation metrics (MAE, RMSE, WAPE) and writes them, together with the model parameters, into a MaxCompute table.

Note: the odpscmd client tool must be installed in advance (the script reads its connection settings from odps_config.ini).
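
The metrics are written into a partitioned MaxCompute table, one partition per hyperparameter combination (see the INSERT OVERWRITE at the end of run_job), and that table has to exist before the script runs. A minimal sketch of the DDL, assuming the table name used below and three DOUBLE columns (the column names mae/rmse/wape are an assumption; only their order matters for the INSERT):

odpscmd --config=odps_config.ini -e "
CREATE TABLE IF NOT EXISTS ps_smart_0804_ltv (
    mae   DOUBLE,
    rmse  DOUBLE,
    wape  DOUBLE
) PARTITIONED BY (pt STRING);"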

#!/bin/bash
#set -e
odps='/home/weisu.yxd/.odps/bin/odpscmd --config=odps_config.ini'
#work_dir=$(readlink -f $(dirname $0))
#echo "work_dir=$work_dir"

## hyper param search config
tree_count=(350 300 250)
max_depth=(5 6)
feature_ratio=(1 0.9 0.8)
l1_reg=(0 1 2)
l2_reg=(0 1)
shrinkage=(0.05 0.1)
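# The grid above yields 3*2*3*3*2*2 = 216 combinations; grid_search skips the
# 36 with l1 = 0 and l2 = 0, so 180 jobs are written to the params file.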

hyper_params_file='hyper_params.txt'

# input tables for training and prediction (set these to your own MaxCompute tables)
input_train_table='<your_train_table>'
input_test_table='<your_test_table>'

function log_info()
{
    if [ "$LOG_LEVEL" != "WARN" ] && [ "$LOG_LEVEL" != "ERROR" ]
    then
        echo "`date +"%Y-%m-%d %H:%M:%S"` [INFO] ($$)($USER): $*";
    fi
}

function grid_search()
{
    log_info "function [$FUNCNAME] begin"
    if [ -f ${hyper_params_file} ]; then
        rm ${hyper_params_file}
    fi
    for tree in ${tree_count[@]}; do
      for depth in ${max_depth[@]}; do
        for fea_ratio in ${feature_ratio[@]}; do
          for l1 in ${l1_reg[@]}; do
            for l2 in ${l2_reg[@]}; do
              # skip the combination with no regularization at all (l1 = 0 and l2 = 0)
              if [ "${l1}" -eq 0 ] && [ "${l2}" -eq 0 ]; then
                continue
              fi
              for lr in ${shrinkage[@]}; do
                echo ${tree} ${depth} ${fea_ratio} ${l1} ${l2} ${lr} >> ${hyper_params_file}
              done
            done
          done
        done
      done
    done
    log_info "function [$FUNCNAME] end"
}

function run_job() {
    log_info "function [$FUNCNAME] begin"
    local tree_count=$1
    local max_depth=$2
    local fea_ratio=$3
    local l1=$4
    local l2=$5
    local lr=$6
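    # build a unique suffix for model/table names; "0." is replaced with "p"
    # so that values like 0.9 become p9 (dots are not allowed in table names)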
    local model=${tree_count}_${max_depth}_${fea_ratio/0./p}_${l1/0./p}_${l2/0./p}_${lr/0./p}
    log_info "run model: $model"

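    # clean up leftovers from a previous run, then train the GBDT model with ps_smart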
    $odps -e "drop table if exists ps_${model};"
    $odps -e "drop table if exists model_output_${model};"
    $odps -e "drop table if exists imp_${model};"
    $odps -e "PAI -name ps_smart
    -project algo_public
    -DinputTableName='${input_train_table}'
    -DmodelName='ps_${model}'
    -DoutputTableName='model_output_${model}'
    -DoutputImportanceTableName='imp_${model}'
    -DlabelColName='targetprice'
    -DfeatureColNames='kv'
    -DenableSparse='true'
    -Dobjective='reg:tweedie'
    -Dmetric='tweedie-nloglik'
    -DfeatureImportanceType='gain'
    -DtreeCount='${tree_count}'
    -DmaxDepth='${max_depth}'
    -Dshrinkage='${lr}'
    -Dl2='${l2}'
    -Dl1='${l1}'
    -Dlifecycle='31'
    -DsketchEps='0.03'
    -DsampleRatio='1.0'
    -DfeatureRatio='${fea_ratio}'
    -DbaseScore='0.0'
    -DminSplitLoss='0'
    "
    local ret=$?  # save the exit status; "$?" inside the if would be the status of the test itself
    if [ ${ret} -ne 0 ]; then
        return ${ret}
    fi

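    # run batch prediction on the test set with the trained model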
    $odps -e "drop table if exists output_${model};"
    $odps -e "PAI -name prediction
    -project algo_public
    -DinputTableName='${input_test_table}'
    -DmodelName='ps_${model}'
    -DoutputTableName='output_${model}'
    -DfeatureColNames='kv'
    -DappendColNames='targetprice'
    -DenableSparse='true'
    -DitemDelimiter=','
    -Dlifecycle='128'
    "
    local ret=$?
    if [ ${ret} -ne 0 ]; then
        return ${ret}
    fi

    # compute MAE / RMSE / WAPE on the prediction output and write them into the
    # partitioned result table (the table must be created in advance; the
    # partition value encodes the hyperparameters)
    $odps -e "INSERT OVERWRITE TABLE ps_smart_0804_ltv PARTITION(pt='${model}')
    SELECT AVG(ABS(targetprice-prediction_result)) MAE,
        SQRT(AVG((targetprice-prediction_result)*(targetprice-prediction_result))) RMSE,
        SUM(ABS(targetprice-prediction_result))/SUM(ABS(targetprice)) WAPE
    FROM output_${model};"
    log_info "function [$FUNCNAME] end"
}

function run_from_file()
{
    log_info "function [$FUNCNAME] begin"
    threadTask=10   # maximum number of concurrent jobs
    fifoFile="test_fifo"
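    # The FIFO acts as a counting semaphore: it is pre-filled with ${threadTask}
    # tokens, each background job takes one token before it starts (read -u9)
    # and puts it back when it finishes (echo >&9), so at most ${threadTask}
    # jobs run at the same time.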
    rm -f ${fifoFile}
    mkfifo ${fifoFile}      # create the FIFO (named pipe)
    exec 9<> ${fifoFile}    # open it for reading and writing on file descriptor 9
    rm -f ${fifoFile}       # fd 9 stays open, so the path itself is no longer needed
    # pre-fill the pipe with one token per allowed concurrent job
    for ((i=0;i<${threadTask};i++))
    do
        echo "" >&9
    done

    log_info "wait all task finish,then exit!!!"
    while read line
    do
        read -u9
        {
            run_job $line
            echo "" >&9
        } &
    done < $1
    wait

    exec 9<&-  # close the read end of fd 9
    exec 9>&-  # close the write end of fd 9
    log_info "function [$FUNCNAME] end"
}

grid_search
run_from_file ${hyper_params_file}
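
When all jobs have finished, every partition of ps_smart_0804_ltv holds the metrics of one hyperparameter combination (the partition value pt is the encoded parameter string), so picking the best configuration is a single query. For example, sorted by WAPE (a sketch; the column names follow the DDL assumed above):

odpscmd --config=odps_config.ini -e "SELECT pt, mae, rmse, wape FROM ps_smart_0804_ltv ORDER BY wape LIMIT 10;"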
