An Engineering Implementation of Swing Recall on Spark

How Swing works

Swing is a graph-based real-time recommendation algorithm that computes item-item similarity. The name refers to a playground swing: the user-item bipartite graph contains many swing structures such as (u1, u2, i1), where users u1 and u2 have both purchased item i1; the three nodes form a swing (a triangle missing one edge). This is a third-order interaction. Traditional heuristic nearest-neighbor methods only consider second-order user-item interactions, whereas Swing exploits these third-order relations. The intuition: if several users who clicked i1 also all clicked the same single other item i2, then i1 and i2 must be strongly related; this latent relation is propagated through the users. Conversely, the more swing structures a pair of users forms, the weaker the evidence each individual structure carries, so each one receives a lower weight. Writing U_i for the set of users who interacted with item i and I_u for the set of items user u interacted with, the similarity is:

Sim(i,j) = \sum_{u \in U_i \cap U_j} \sum_{v \in U_i \cap U_j} \frac{1}{\alpha + |I_u \cap I_v|}
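
To make the formula concrete, here is a tiny plain-Scala sanity check on toy data (the function and the toy rows are illustrative only, not part of the Spark implementation below):

// Non-distributed sketch of the Swing formula, for intuition only.
def swingSim(i: String, j: String,
             userItems: Map[String, Set[String]], alpha: Double): Double = {
  // U_i ∩ U_j: users who interacted with both i and j
  val common = userItems.collect { case (u, items) if items(i) && items(j) => u }.toSeq
  // sum 1 / (alpha + |I_u ∩ I_v|) over distinct user pairs
  (for {
    a <- common.indices
    b <- a + 1 until common.size
  } yield 1.0 / (alpha + userItems(common(a)).intersect(userItems(common(b))).size)).sum
}

val userItems = Map(
  "u1" -> Set("i1", "i2", "i3"),
  "u2" -> Set("i1", "i2"),
  "u3" -> Set("i1", "i3"))
// u1 and u2 are the only user pair that co-clicked both i1 and i2;
// |I_u1 ∩ I_u2| = 2, so with alpha = 1: Sim(i1, i2) = 1 / (1 + 2) ≈ 0.333
swingSim("i1", "i2", userItems, alpha = 1.0)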

Swing implementation based on Spark DataFrames

package entries

// item information
case class Item(item_id: String)

// a user-item-rating record
case class Rating(user_id: String, item_id: String, rating: Double)

// user information
case class User(user_id: String)
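
For reference, a minimal way to build the input Dataset[Rating] by hand (the toy rows and the local master setting are for illustration; the real log loader appears in the main function below):

import org.apache.spark.sql.SparkSession
import entries.Rating

val spark = SparkSession.builder().appName("demo").master("local[*]").getOrCreate()
import spark.implicits._

// four interactions: u1 and u2 each rated i1 and i2
val ratings = Seq(
  Rating("u1", "i1", 1.0), Rating("u1", "i2", 1.0),
  Rating("u2", "i1", 1.0), Rating("u2", "i2", 1.0)
).toDS()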

SwingModel

package model.graph

import entries.Rating
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

/**
  * @author jinfeng
  * @param spark SparkSession
  * @return this model; the output schema of item2item is:
  * root
  * |-- item_id: string (nullable = true)
  * |-- sorted_items: array (nullable = true)
  * |    |-- element: struct (containsNull = true)
  * |    |    |-- _1: string (nullable = true)
  * |    |    |-- _2: double (nullable = true)
  * @example
  * val model = new SwingModel(spark)
  * .setAlpha(1)
  * .setTop_N_Items(100)
  * .setParallelism(100)
  * val ret = model.fit(df)
  * @version 1.0
  */
class SwingModel(spark: SparkSession) extends Serializable {

  import spark.implicits._

  var defaultParallelism: Int = spark.sparkContext.defaultParallelism
  var similarities: Option[DataFrame] = None
  var alpha: Option[Double] = Option(0.0)
  var top_n_items: Option[Int] = Option(100)

  /**
    * @param parallelism parallelism; if unset, Spark's default parallelism is used
    * @return
    */
  def setParallelism(parallelism: Int): SwingModel = {
    this.defaultParallelism = parallelism
    this
  }

  /**
    * @param alpha the smoothing constant alpha in the Swing formula
    * @return
    */
  def setAlpha(alpha: Double): SwingModel = {
    this.alpha = Option(alpha)
    this
  }

  /**
    * @param top_n_items before computing similarities, items are ranked by interaction count in descending order and only the top_n_items items are kept
    * @return
    */
  def setTop_N_Items(top_n_items: Int): SwingModel = {
    this.top_n_items = Option(top_n_items)
    this
  }

  /**
    * @param ratings the rating Dataset
    * @return
    */
  def fit(ratings: Dataset[Rating]): SwingModel = {

    // UDF: given the user sets of two items (each user carried as a struct of
    // (user_id, item_set)), sum 1 / (alpha + |I_u ∩ I_v|) over all pairs of
    // users who interacted with both items
    def interWithAlpha = udf(
      (array_1: Seq[GenericRowWithSchema], array_2: Seq[GenericRowWithSchema]) => {
        var score = 0.0
        val set_1 = array_1.toSet
        val set_2 = array_2.toSet
        // U_i ∩ U_j: users shared by the two items
        val user_set = set_1.intersect(set_2).toArray
        for (i <- user_set.indices; j <- i + 1 until user_set.length) {
          val user_1 = user_set(i)
          val user_2 = user_set(j)
          // field "_2" holds the user's full item set, i.e. I_u
          val item_set_1 = user_1.getAs[Seq[String]]("_2").toSet
          val item_set_2 = user_2.getAs[Seq[String]]("_2").toSet
          score = score + 1 / (item_set_1.intersect(item_set_2).size.toDouble + this.alpha.get)
        }
        score
      }
    )

    val df = ratings.repartition(defaultParallelism).cache()
    // collect each user's item set: (user_id, I_u)
    val groupUsers = df.groupBy("user_id")
      .agg(collect_set("item_id"))
      .toDF("user_id", "item_set")
      .repartition(defaultParallelism)
    // collect each item's user set, every user carrying its full item set:
    // (item_id, {(user_id, I_u)}); only the top_n_items items by interaction
    // count are kept, which bounds the pairwise join below
    val groupItems = df.join(groupUsers, "user_id")
      .rdd.map { x =>
        val item_id = x.getAs[String]("item_id")
        val user_id = x.getAs[String]("user_id")
        val item_set = x.getAs[Seq[String]]("item_set")
        (item_id, (user_id, item_set))
      }.toDF("item_id", "user")
      .groupBy("item_id")
      .agg(collect_set("user"), count("item_id"))
      .toDF("item_id", "user_set", "count")
      .sort($"count".desc)
      .limit(this.top_n_items.get)
      .drop("count")
      .repartition(defaultParallelism)
      .cache()
    // pairwise self-join over the truncated item list (a cartesian product,
    // kept map-side by broadcasting one side), then score each item pair
    val itemJoined = groupItems.join(broadcast(groupItems))
      .toDF("item_id_1", "user_set_1", "item_id_2", "user_set_2")
      .filter("item_id_1 <> item_id_2")
      .withColumn("score", interWithAlpha(col("user_set_1"), col("user_set_2")))
      .select("item_id_1", "item_id_2", "score")
      .filter("score > 0")
      .repartition(defaultParallelism)
      .cache()
    similarities = Option(itemJoined)
    this
  }

  /**
    * Aggregates the fit result by item_id, sorts each item's neighbors by
    * score in descending order, and keeps the top num neighbors per item.
    *
    * @param num number of neighbors to keep per item
    * @return
    */
  def item2item(num: Int): DataFrame = {
    require(similarities.isDefined, "item2item must be called after fit")
    val sim = similarities.get.select("item_id_1", "item_id_2", "score")
    val topN = sim
      .map { x =>
        val item_id_1 = x.getAs[String]("item_id_1")
        val item_id_2 = x.getAs[String]("item_id_2")
        val score = x.getAs[Double]("score")
        (item_id_1, (item_id_2, score))
      }.toDF("item_id", "itemWithScore")
      .groupBy("item_id").agg(collect_set("itemWithScore"))
      .toDF("item_id", "item_set")
      .rdd.map { x =>
        val item_id_1 = x.getAs[String]("item_id")
        // sort each item's neighbors by score descending and keep the top num
        val item_set = x.getAs[Seq[GenericRowWithSchema]]("item_set")
          .map { y =>
            val item_id_2 = y.getAs[String]("_1")
            val score = y.getAs[Double]("_2")
            (item_id_2, score)
          }.sortBy(-_._2).take(num)
        (item_id_1, item_set)
      }.toDF("item_id", "sorted_items")
      .filter("size(sorted_items) > 0")
    topN
  }
}
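
item2item keeps each item's neighbors nested in the sorted_items array column. For writing to a KV store the result is often flattened first; a minimal sketch, assuming the output schema shown in the Scaladoc above (ret is the DataFrame returned by item2item):

import org.apache.spark.sql.functions._

// one row per (item, neighbor, score)
val flat = ret
  .select(col("item_id"), explode(col("sorted_items")).as("neighbor"))
  .select(col("item_id"),
    col("neighbor._1").as("sim_item_id"),
    col("neighbor._2").as("score"))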

The Swing main function

package main

import common.LogDataProcess
import model.graph.SwingModel
import org.apache.spark.sql.SparkSession

/**
  * @author jinfeng
  * @version 1.0
  */
object Swing {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("SwingModel").getOrCreate()
    spark.sparkContext.setLogLevel("ERROR")
    // args: <log_url> <top_n_items> <alpha> <num> <dest_url>
    val Array(log_url, top_n_items, alpha, num, dest_url) = args
    val model = new SwingModel(spark)
      .setAlpha(alpha.toDouble)
      .setTop_N_Items(top_n_items.toInt)
    val ratings = LogDataProcess.getRatingLog(spark, log_url)
    val df = model.fit(ratings).item2item(num.toInt)
    df.write.mode("overwrite").parquet(dest_url)
  }

}
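
The post imports common.LogDataProcess but does not show it. A minimal sketch of what it might look like, assuming the behavior log is stored as Parquet with user_id / item_id / rating columns (the storage format and column names are assumptions, not from the post):

package common

import entries.Rating
import org.apache.spark.sql.{Dataset, SparkSession}

object LogDataProcess {
  // assumed loader: read a Parquet log and map it onto Rating
  def getRatingLog(spark: SparkSession, logUrl: String): Dataset[Rating] = {
    import spark.implicits._
    spark.read.parquet(logUrl)
      .selectExpr(
        "cast(user_id as string) as user_id",
        "cast(item_id as string) as item_id",
        "cast(rating as double) as rating")
      .as[Rating]
  }
}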