Spark ML: Recommendation Algorithm Project (Part 1)


    I. Overall workflow

    II. Detailed recall workflow

    III. Code implementation

     0. Filter out delisted goods, adult products, tobacco/alcohol, etc.

    package com.njbdqn.filter
    
    import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
    import org.apache.spark.sql.SparkSession
    
    object BanGoodFilter {
    
      /**
       *  Remove goods that must not be recommended and save the remaining goods to HDFS
       * @param spark
       */
      def ban(spark:SparkSession): Unit ={
        // read the raw goods data
        val goodsDf = MYSQLConnection.readMySql(spark, "goods")
        // filter out delisted goods (already sold) and keep the unsold (is_sale = 0) goods on HDFS
        val gd = goodsDf.filter("is_sale=0")
        HDFSConnection.writeDataToHDFS("/myshops/dwd_good",gd)
      }
    }
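
    Every block in this post relies on two helper objects, MYSQLConnection and HDFSConnection, whose implementations are not shown in the article. Below is a minimal sketch of what they could look like, matching only the call signatures used here; the JDBC URL, the credentials, and the choice of parquet are assumptions, not the project's actual code.

    package com.njbdqn.util
    
    import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
    
    object MYSQLConnection {
      // placeholder connection settings -- replace with the real ones
      private val url = "jdbc:mysql://localhost:3306/myshops"
      private val props = new java.util.Properties()
      props.put("user", "root")
      props.put("password", "secret")
    
      // read one MySQL table as a DataFrame
      def readMySql(spark: SparkSession, table: String): DataFrame =
        spark.read.jdbc(url, table, props)
    
      // overwrite a MySQL table with the given DataFrame
      def writeTable(spark: SparkSession, df: DataFrame, table: String): Unit =
        df.write.mode(SaveMode.Overwrite).jdbc(url, table, props)
    }
    
    object HDFSConnection {
      // persist a DataFrame to HDFS at the given path (parquet assumed)
      def writeDataToHDFS(path: String, df: DataFrame): Unit =
        df.write.mode(SaveMode.Overwrite).parquet(path)
    
      // read a DataFrame back from HDFS
      def readDataToHDFS(spark: SparkSession, path: String): DataFrame =
        spark.read.parquet(path)
    }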

     1. Global recall of hot items, cross joined to every user (so that every user has something to recommend)

    package com.njbdqn.call
    
    import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{desc, row_number, sum}
    
    /**
     *  Global hot-seller recall
     */
    object GlobalHotCall {
    
      def hotSell(spark:SparkSession): Unit ={
        val oitab = MYSQLConnection.readMySql(spark, "orderItems").cache()
        // compute the global top-100 best-selling goods ( good_id, sellnum )
        import spark.implicits._
        val top100 = oitab
          .groupBy("good_id")
          .agg(sum("buy_num").alias("sellnum"))
          .withColumn("rank",row_number().over(Window.orderBy(desc("sellnum"))))
          .limit(100)
        // cross join every user id with the top-100 hot sellers
        val custstab = MYSQLConnection.readMySql(spark,"customs")
          .select($"cust_id").cache()
        val res = custstab.crossJoin(top100)
        HDFSConnection.writeDataToHDFS("/myshops/dwd_hotsell",res)
        // for guest users: write the top-10 hot sellers to the MySQL table hotsell
        val hotsell =top100.limit(10)
        MYSQLConnection.writeTable(spark,hotsell,"hotsell")
      }
    }

     2. Group-based recall

    For details, see https://www.cnblogs.com/sabertobih/p/13824739.html

    Data preparation and normalization:

    package com.njbdqn.datahandler
    
    import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
    import org.apache.spark.ml.clustering.KMeans
    import org.apache.spark.ml.feature.{MinMaxScaler, StringIndexer, VectorAssembler}
    import org.apache.spark.sql.{DataFrame, SparkSession}
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, count, current_date, datediff, desc, min, row_number, sum, udf}
    import org.apache.spark.sql.types.DoubleType
    
    object KMeansHandler {
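      // bucket the raw membership_level value into 4 tiers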
      val func_membership = udf {
        (score: Int) => {
          score match {
            case i if i < 100 => 1
            case i if i < 500 => 2
            case i if i < 1000 => 3
            case _ => 4
          }
        }
      }
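      // age in whole years, computed from the birth date embedded in the id number (idno) and today's date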
      val func_bir = udf {
        (idno: String, now: String) => {
          val year = idno.substring(6, 10).toInt
          val month = idno.substring(10, 12).toInt
          val day = idno.substring(12, 14).toInt
    
          val dts = now.split("-")
          val nowYear = dts(0).toInt
          val nowMonth = dts(1).toInt
          val nowDay = dts(2).toInt
    
          if (nowMonth > month) {
            nowYear - year
          } else if (nowMonth < month) {
            nowYear - 1 - year
          } else {
            if (nowDay >= day) {
              nowYear - year
            } else {
              nowYear - 1 - year
            }
          }
        }
      }
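      // bucket an age in years into 7 age bands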
      val func_age = udf {
        (num: Int) => {
          num match {
            case n if n < 10 => 1
            case n if n < 18 => 2
            case n if n < 23 => 3
            case n if n < 35 => 4
            case n if n < 50 => 5
            case n if n < 70 => 6
            case _ => 7
          }
        }
      }
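      // bucket the user's biz_point score into 3 tiers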
      val func_userscore = udf {
        (sc: Int) => {
          sc match {
            case s if s < 100 => 1
            case s if s < 500 => 2
            case _ => 3
          }
        }
      }
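      // bucket login_count into 2 tiers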
      val func_logincount = udf {
        (sc: Int) => {
          sc match {
            case s if s < 500 => 1
            case _ => 2
          }
        }
      }
    
      // combine each user's demographic attributes with their behaviour
      def user_act_info(spark:SparkSession): DataFrame ={
        val featureDataTable = MYSQLConnection.readMySql(spark,"customs").filter("active!=0")
          .select("cust_id", "company", "province_id", "city_id", "district_id"
            , "membership_level", "create_at", "last_login_time", "idno", "biz_point", "sex", "marital_status", "education_id"
            , "login_count", "vocation", "post")
        // goods table
        val goodTable=HDFSConnection.readDataToHDFS(spark,"/myshops/dwd_good").select("good_id","price")
        // orders table
        val orderTable=MYSQLConnection.readMySql(spark,"orders").select("ord_id","cust_id")
        // order items (order detail) table
        val orddetailTable=MYSQLConnection.readMySql(spark,"orderItems").select("ord_id","good_id","buy_num")
        // first convert the company name to a numeric index with StringIndexer
        val compIndex = new StringIndexer().setInputCol("company").setOutputCol("compId")
        // use the custom UDF functions defined above
        import spark.implicits._
        // number of orders placed by each user
        val tmp_bc=orderTable.groupBy("cust_id").agg(count($"ord_id").as("buycount"))
        // total amount of money each user has spent on the site
        val tmp_pay=orderTable.join(orddetailTable,Seq("ord_id"),"inner").join(goodTable,Seq("good_id"),"inner").groupBy("cust_id").
          agg(sum($"buy_num"*$"price").as("pay"))
    
        compIndex.fit(featureDataTable).transform(featureDataTable)
          .withColumn("mslevel", func_membership($"membership_level"))
          .withColumn("min_reg_date", min($"create_at") over())
          .withColumn("reg_date", datediff($"create_at", $"min_reg_date"))
          .withColumn("min_login_time", min("last_login_time") over())
          .withColumn("lasttime", datediff($"last_login_time", $"min_login_time"))
          .withColumn("age", func_age(func_bir($"idno", current_date())))
          .withColumn("user_score", func_userscore($"biz_point"))
          .withColumn("logincount", func_logincount($"login_count"))
          // right tables: some users may never have bought / spent anything and are missing there, so left join to keep every user
          .join(tmp_bc,Seq("cust_id"),"left").join(tmp_pay,Seq("cust_id"),"left")
          .na.fill(0)
          .drop("company", "membership_level", "create_at", "min_reg_date"
            , "last_login_time", "min_login_time", "idno", "biz_point", "login_count")
      }
      // user grouping: build and normalize the per-user feature vectors
      def user_group(spark:SparkSession) = {
        val df = user_act_info(spark)
        // cast every column to Double
        val columns=df.columns.map(f=>col(f).cast(DoubleType))
        val num_fmt=df.select(columns:_*)
        // assemble all columns except the first (cust_id) into a single vector column
        val va= new VectorAssembler()
          .setInputCols(Array("province_id","city_id","district_id","sex","marital_status","education_id","vocation","post","compId","mslevel","reg_date","lasttime","age","user_score","logincount","buycount","pay"))
          .setOutputCol("orign_feature")
        val ofdf=va.transform(num_fmt).select("cust_id","orign_feature")
        // min-max normalize the raw feature column
        val mmScaler:MinMaxScaler=new MinMaxScaler().setInputCol("orign_feature").setOutputCol("feature")
        // fit produces the scaler model, then transform ofdf with it
        mmScaler.fit(ofdf)
          .transform(ofdf)
          .select("cust_id","feature")
    
      }
    }

    K-means grouping and per-group recall:

    package com.njbdqn.call
    
    import com.njbdqn.datahandler.KMeansHandler
    import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
    import org.apache.spark.ml.clustering.KMeans
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._
    
    
    /**
     *  Compute the user groups and the goods recalled for each group
     */
    object GroupCall {
      def calc_groups(spark:SparkSession): Unit ={
        // use the K-means algorithm to group the users
        // to pick K, compute the total distance (SSE) for different numbers of centroids
        // and collect the results:
        //    val disList:ListBuffer[Double]=ListBuffer[Double]()
        //    for (i<-2 to 40){
        //      val kms=new KMeans().setFeaturesCol("feature").setK(i)
        //      val model=kms.fit(resdf)
        //      // why no transform here?
        //      // the goal is not a df of (cust_id, feature) plus the group (prediction);
        //      // the goal is to use computeCost to get the SSE for each value of K
        //      disList.append(model.computeCost(resdf))
        //    }
        //    // plot the curve with the charting utility
        //    val chart=new LineGraph("app","K-means: centroids vs. distance",disList)
        //    chart.pack()
        //    RefineryUtilities.centerFrameOnScreen(chart)
        //    chart.setVisible(true)
    
        import spark.implicits._
        val orderTable=MYSQLConnection.readMySql(spark,"orders").select("ord_id","cust_id")
        val orddetailTable=MYSQLConnection.readMySql(spark,"orderItems").select("ord_id","good_id","buy_num")
        val resdf = KMeansHandler.user_group(spark)
        // group the users with K-means, using a stable K value (chosen from the elbow curve above)
        val kms=new KMeans().setFeaturesCol("feature").setK(40)
        // the group each user belongs to: (cust_id, groups), e.g. (1, 0)
        val user_group_tab=kms.fit(resdf)
          .transform(resdf)
          .drop("feature")
          .withColumnRenamed("prediction","groups").cache()

        // top 30 goods bought within each group:
        // group by (groups, good_id) and count the orders (ord_id),
        // then row_number over a window partitioned by group, ordered by the buy count desc
        val rank=30
        val wnd=Window.partitionBy("groups").orderBy(desc("group_buy_count"))

        val groups_goods = user_group_tab.join(orderTable,Seq("cust_id"),"inner")
          .join(orddetailTable,Seq("ord_id"),"inner")
          .na.fill(0)
          .groupBy("groups","good_id")
          .agg(count("ord_id").as("group_buy_count"))
          .withColumn("rank",row_number()over(wnd))
          .filter($"rank"<=rank)
        // the goods recommended to each user's group (and therefore to every user in it)
        val df5 = user_group_tab.join(groups_goods,Seq("groups"),"inner")
        HDFSConnection.writeDataToHDFS("/myshops/dwd_kMeans",df5)
      }
    }
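
    The elbow-method block inside calc_groups is commented out and depends on a LineGraph charting class that is not part of this post. A small variant, meant to be dropped into calc_groups (so the existing imports and SparkSession apply), that simply prints the cost for each K so the elbow can be read off the console:

    // adapted from the commented-out block above: print K and its within-cluster SSE
    // instead of plotting it (the LineGraph class is not shown in this post)
    import org.apache.spark.ml.clustering.KMeans
    
    val resdf = KMeansHandler.user_group(spark).cache()
    (2 to 40).foreach { k =>
      val model = new KMeans().setFeaturesCol("feature").setK(k).fit(resdf)
      // computeCost returns the sum of squared distances to the nearest centroid
      println(s"K=$k  cost=${model.computeCost(resdf)}")
    }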

     3. ALS collaborative-filtering recall

    ALS data preprocessing: the scores in the user-item sparse matrix have to be quantified as numbers and every column must be fully numeric; the sparse table is then converted into a collection of Rating objects.

    package com.njbdqn.datahandler
    
    import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
    import org.apache.spark.mllib.recommendation.Rating
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{DataFrame, SparkSession}
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{row_number, sum, udf}
    
    object ALSDataHandler {
      // in case user ids or goods ids contain non-numeric characters, give every goods id and user id a consecutive numeric id and cache the mapping
      def goods_to_num(spark:SparkSession):DataFrame={
        import spark.implicits._
        val wnd1 = Window.orderBy("good_id")
        HDFSConnection.readDataToHDFS(spark,"/myshops/dwd_good").select("good_id","price")
          .select($"good_id",row_number().over(wnd1).alias("gid")).cache()
      }
    
      def user_to_num(spark:SparkSession):DataFrame={
        import spark.implicits._
        val wnd2 = Window.orderBy("cust_id")
        MYSQLConnection.readMySql(spark,"customs")
          .select($"cust_id",row_number().over(wnd2).alias("uid")).cache()
      }
    
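      // weight of each action type: BROWSE=1, COLLECT=2, BUYCAR (add to cart)=3, any other action=8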
      val actToNum=udf{
        (str:String)=>{
          str match {
            case "BROWSE"=>1
            case "COLLECT"=>2
            case "BUYCAR"=>3
            case _=>8
          }
        }
      }
    
      case class UserAction(act:String,act_time:String,cust_id:String,good_id:String,browse:String)
    
      def als_data(spark:SparkSession): RDD[Rating] ={
        val goodstab:DataFrame = goods_to_num(spark)
        val custstab:DataFrame = user_to_num(spark)
        val txt = spark.sparkContext.textFile("file:///D:/logs/virtualLogs/*.log").cache()
        import spark.implicits._
        // compute each user's aggregated score for every good that user has interacted with
        val df = txt.map(line=>{
          val arr = line.split(" ")
          UserAction(arr(0),arr(1),arr(2),arr(3),arr(4))
        }).toDF().drop("act_time","browse")
          .select($"cust_id",$"good_id",actToNum($"act").alias("score"))
          .groupBy("cust_id","good_id")
          .agg(sum($"score").alias("score"))
        // the goods and user ids were mapped to consecutive numeric ids above;
        // join df with goodstab and custstab and keep only (gid, uid, score)
        val df2 = df.join(goodstab,Seq("good_id"),"inner")
          .join(custstab,Seq("cust_id"),"inner")
          .select("gid","uid","score")
        //.show(20)
        // convert the sparse table into a collection of Rating objects
        val allData:RDD[Rating] = df2.rdd.map(row=>{
          Rating(
            row.getAs("uid").toString.toInt,
            row.getAs("gid").toString.toInt,
            row.getAs("score").toString.toFloat
          )})
        allData
      }
    }

    ALS training; at the end the numeric ids have to be mapped back to the original ids (number => non-number).

    package com.njbdqn.call
    
    import com.njbdqn.datahandler.ALSDataHandler
    import com.njbdqn.datahandler.ALSDataHandler.{goods_to_num, user_to_num}
    import com.njbdqn.util.{HDFSConnection, MYSQLConnection}
    import org.apache.spark.mllib.recommendation.{ALS, Rating}
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{DataFrame, SparkSession}
    
    object ALSCall {
      def als_call(spark:SparkSession): Unit ={
        val goodstab:DataFrame = goods_to_num(spark)
        val custstab:DataFrame = user_to_num(spark)
        val alldata: RDD[Rating] = ALSDataHandler.als_data(spark).cache()
        // the Rating collection could be split into train/test sets at an 0.8/0.2 ratio:
        // val Array(train,test) = alldata.randomSplit(Array(0.8,0.2))
        // the split is left commented out here, so the model is trained on the full data set
        val model = new ALS().setCheckpointInterval(2).setRank(10).setIterations(20).setLambda(0.01).setImplicitPrefs(false)
          .run(alldata)
        // use the trained model to recommend the top 30 goods for every user
        val tj = model.recommendProductsForUsers(30)
        import spark.implicits._
        // (uid,gid,rank)
        val df5 = tj.flatMap{
          case(user:Int,ratings:Array[Rating])=>
            ratings.map{case (rat:Rating)=>(user,rat.product,rat.rating)}
        }.toDF("uid","gid","rank")
          // map back to (cust_id, good_id, rank)
          .join(goodstab,Seq("gid"),"inner")
          .join(custstab,Seq("uid"),"inner")
          .select($"cust_id",$"good_id",$"rank")
       //   .show(false)
        HDFSConnection.writeDataToHDFS("/myshops/dwd_ALS_Iter20",df5)
      }
    }
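
    Because the 80/20 split inside als_call is left commented out, the model above is trained on all of the data and never evaluated. A minimal sketch of how a held-out 20% could be used to compute an RMSE, using only the mllib Rating/ALS API already shown; it assumes a SparkSession named spark is in scope and the variable names are illustrative:

    import org.apache.spark.mllib.recommendation.{ALS, Rating}
    
    val Array(train, test) = ALSDataHandler.als_data(spark).randomSplit(Array(0.8, 0.2))
    // train with the same hyper-parameters as als_call: rank=10, iterations=20, lambda=0.01
    val model = ALS.train(train, 10, 20, 0.01)
    
    // predict a score for every (user, product) pair in the test set
    val predictions = model.predict(test.map(r => (r.user, r.product)))
      .map(r => ((r.user, r.product), r.rating))
    val ratesAndPreds = test.map(r => ((r.user, r.product), r.rating)).join(predictions)
    
    // root-mean-square error between the observed and predicted scores
    val rmse = math.sqrt(ratesAndPreds.map { case (_, (obs, pred)) =>
      val err = obs - pred; err * err
    }.mean())
    println(s"test RMSE = $rmse")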

     
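    To tie the four steps together, a hypothetical driver could call them in the order of this post. The object name RecommendJob and the local master setting are assumptions, not part of the original project:

    package com.njbdqn
    
    import com.njbdqn.call.{ALSCall, GlobalHotCall, GroupCall}
    import com.njbdqn.filter.BanGoodFilter
    import org.apache.spark.sql.SparkSession
    
    object RecommendJob {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("recommend")
          .master("local[*]")          // assumption: run locally while developing
          .getOrCreate()
    
        BanGoodFilter.ban(spark)       // 0. drop goods that must not be recommended
        GlobalHotCall.hotSell(spark)   // 1. global hot-seller recall
        GroupCall.calc_groups(spark)   // 2. K-means group recall
        ALSCall.als_call(spark)        // 3. ALS collaborative-filtering recall
    
        spark.stop()
      }
    }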
