• Spark MLlib FPGrowth关联规则算法


    一.简介

      FPGrowth算法是关联分析算法,它采取如下分治策略:将提供频繁项集的数据库压缩到一棵频繁模式树(FP-tree),但仍保留项集关联信息。在算法中使用了一种称为频繁模式树(Frequent Pattern Tree)的数据结构。FP-tree是一种特殊的前缀树,由频繁项头表和项前缀树构成。

      相关术语:

        1.项与项集

          这是一个集合的概念,以购物车为例,一件商品就是一项【item】,若干项的集合为项集,如{特步鞋,安踏运动服}为一个二元项集。

        2.关联规则

          关联规则用于表示数据内隐含的关联性,例如买了新鞋的客户也往往会买袜子。

        3.支持度

          支持度是指在所有项集中{x,y}出现的可能性,即项集中同时出现含有x和y的概率。该指标作为建立强关联规则的第一个门槛,衡量了所考察关联规则在“量”上的多少。

        4.置信度

          表示在先决条件x发生的情况下,关联结果y发生的概率。这是生成强关联规则的第二个门槛,衡量了所考察的关联规则在“质”上的可靠性。

        5.提升度

          表示在含有x的条件下同时含有y的可能性与没有x的条件下项集含有y的可能性之比。

    二.测试数据 

    r z h k p
    z y x w v u t s
    s x o n r
    x z y m t s q e
    z
    x z y r q t p

    三.代码实现 

    package big.data.analyse.mllib
    
    import org.apache.log4j.{Level, Logger}
    import org.apache.spark.mllib.fpm.FPGrowth
    import org.apache.spark.{SparkContext, SparkConf}
    
    /**
      * 关联规则
      * Created by zhen on 2019/4/11.
      */
    object FPG {
      Logger.getLogger("org").setLevel(Level.WARN)
      def main(args: Array[String]) {
        val conf = new SparkConf()
        conf.setAppName("fpg")
        conf.setMaster("local[2]")
    
        val sc = new SparkContext(conf)
    
        /**
          * 加载数据
          */
        val data = sc.textFile("data/mllib/sample_fpgrowth.txt")
        val data_spl = data.map(row => row.split(" ")).cache()
    
        /**
          * 创建模型
          */
        val minSupport = 0.2
        val numPartition = 10
        val model = new FPGrowth()
          .setMinSupport(minSupport)
          .setNumPartitions(numPartition)
          .run(data_spl)
    
        /**
          * 打印结果
          */
        println("Number of frequent itemsets : " + model.freqItemsets.count())
        model.freqItemsets.collect.foreach{itemset =>
          println(itemset.items.mkString("[", ",", "]") + " ==> " + itemset.freq)
        }
      }
    }

    四.结果

       .......

    五.精简测试数据

      y z

      z y x

      x

      x z y

      z

      x z

    六.二次开发代码实现

    package big.data.analyse.mllib
    
    import org.apache.log4j.{Level, Logger}
    import org.apache.spark.mllib.fpm.FPGrowth
    import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
    import org.apache.spark.sql.{Row, SQLContext}
    import org.apache.spark.{SparkContext, SparkConf}
    
    /**
      * 关联规则
      * Created by zhen on 2019/4/11.
      */
    object FPG {
      Logger.getLogger("org").setLevel(Level.WARN)
      def main(args: Array[String]) {
        val conf = new SparkConf()
        conf.setAppName("fpg")
        conf.setMaster("local[2]")
    
        val sc = new SparkContext(conf)
        val sqlContext = new SQLContext(sc)
    
        /**
          * 加载数据
          */
        val data = sc.textFile("data/mllib/sample_fpgrowth.txt")
        val data_spl = data.map(row => row.split(" ")).cache()
    
        /**
          * 创建模型
          */
        val minSupport = 0.2
        val numPartition = 10
        val model = new FPGrowth()
          .setMinSupport(minSupport)
          .setNumPartitions(numPartition)
          .run(data_spl)
    
        /**
          * 打印结果
          */
        //println("Number of frequent itemsets : " + model.freqItemsets.count())
        model.freqItemsets.collect.foreach{itemset =>
          println(itemset.items.mkString("[", "-", "]") + " ==> " + itemset.freq)
        }
    
        /**
          * 把结果数据转换为Map
          */
        val map = model.freqItemsets
          .map{row =>
            var map : Map[String,Double] = Map()
            map += (row.items.mkString("-") -> row.freq.toDouble)
            map
          }.collect().flatten.toMap
    
        val list = map.keysIterator.toList
    
        /**
          * 拆分比较,计算概率
          */
        var mid_result : Map[String, Double] = Map()
    
        for(i <- 0 until list.length){
          for(j <- 0 until list.length){
            if(i != j){
              if(list(i).contains(list(j))){  // xy -> xyz
                var key = ""
                if(list(i).indexOf(list(j)) == 0){ // 子串位于母串开头
                  key = list(j) + "_" + list(i).replace(list(j) + "-", "")
                }else{// 子串位于母串的中间或者末尾
                  key = list(j) + "_" + list(i).replace("-" + list(j), "")
                }
                val left = map(list(j))
                val right = map(list(i))
                val value = right / left
                mid_result += (key -> value)
              }else{// TODO 分开包含的也要加进行,比较顺序不一定一致,例如:xy -> xzy
                val left_key = list(i).split("-")
                val right_key = list(j).split("-")
                var isno = true
                for(x <- 0 until right_key.length){
                  if(!left_key.contains(right_key(x))){
                    isno = false
                  }
                }
                if(isno){ // 包含
                  var mid_key = "" // 拼接key
                  for(y <- 0 until left_key.length){
                    if(!right_key.contains(left_key(y))){
                      mid_key += left_key(y) + "-"
                    }
                  }
                  if(mid_key != ""){ // 清除末尾多余的-
                    mid_key = mid_key.substring(0, mid_key.length-1)
                  }
                  val key = list(j) + "_" + mid_key
                  val left = map(list(j))
                  val right = map(list(i))
                  val value = right / left
                  mid_result += (key -> value)
                }
              }
            }
          }
        }
    
        /**
          *平衡标签先后顺序对概率的影响
          */
        var result : List[String] = List()
        val keys = mid_result.keysIterator.toList
        for(i <- 0 until keys.length){
          println(keys(i) +":"+ mid_result(keys(i)))
        }
        for(i <- 0 until keys.length){
          for(j <- 0 until keys.length){
            if(i != j){
              val left = keys(i).split("_")
              val right = keys(j).split("_")
              if(left(0) == right(1) && left(1) == right(0)){
                val value = ((mid_result(keys(i)) + mid_result(keys(j)))/2).formatted("%.2f") // 保留两位小数
                if(left(0) < left(1)){
                  result = result.:+(left(0) + "_" + left(1) + "_" + value)
                }else{
                  result = result.:+(left(1) + "_" + left(0) + "_" + value)
                }
              }
            }
          }
        }
        result = result.distinct // 去重
        /*for(i <- 0 until result.length){
          println(result(i))
        }*/
    
        /**
          * 转换为rdd
          */
        val result_rdd = sc.parallelize(result).map(row => {
          val Array(left, right, probability) = row.split("_")
          Row(left, right, probability.toDouble)
        })
    
        /**
          * 定义结构
          */
        val structType = new StructType(Array(
          StructField("left", StringType, true),
          StructField("right", StringType, true),
          StructField("probability", DoubleType, true)
        ))
    
        val result_df = sqlContext.createDataFrame(result_rdd, structType)
    
        import org.apache.spark.sql.functions._
        result_df.orderBy(desc("probability")).show()
      }
    }

    七.结果

      

      

      

    八.备注

      集群模式出现以下异常【local模式无异常】;

        can not set final scala.collection.mutable.ListBuffer field org.apache.spark.mllib.fpm.FPTree$Summary.nodes to scala.collection.mutable.ArrayBuffer

       解决方案:

         配置:conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")

  • 相关阅读:
    异常、中断、陷阱
    BigDecimal
    事务
    jsp的九大内置对象
    timer和ScheduledThreadPoolExecutor
    关于Python的导入覆盖解决办法
    RTTI
    Golang不会自动把slice转换成interface{}类型的slice
    Python中下划线的5种含义
    Python如何合并两个字典
  • 原文地址:https://www.cnblogs.com/yszd/p/10691990.html
Copyright © 2020-2023  润新知