• Spark MLlib 之 aggregate和treeAggregate从原理到应用


    在阅读spark mllib源码的时候,发现一个出镜率很高的函数——aggregate和treeAggregate,比如matrix.columnSimilarities()中。为了好好理解这两个方法的使用,于是整理了本篇内容。

    由于treeAggregate是在aggregate基础上的优化版本,因此先来看看aggregate是什么.

    更多内容参考我的大数据学习之路

    aggregate

    先直接看一下代码例子:

    import org.apache.spark.sql.SparkSession
    
    object AggregateTest {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("tf-idf").getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
    	// 创建rdd,并分成6个分区
        val rdd = spark.sparkContext.parallelize(1 to 12).repartition(6)
        // 输出每个分区的内容
        rdd.mapPartitionsWithIndex((index:Int,it:Iterator[Int])=>{
          Array((s" $index : ${it.toList.mkString(",")}")).toIterator
        }).foreach(println)
        // 执行agg
        val res1 = rdd.aggregate(0)(seqOp, combOp)
      }
      // 分区内执行的方法,直接加和
      def seqOp(s1:Int, s2:Int):Int = {
        println("seq: "+s1+":"+s2)
        s1 + s2
      }
      // 在driver端汇总
      def combOp(c1: Int, c2: Int): Int = {
        println("comb: "+c1+":"+c2)
        c1 + c2
      }
    }
    

    这段代码的主要目的就是为了求和。考虑到spark分区并行计算的特性,在每个分区独立加和,最后再汇总加和。

    过程可以参考下面的图片:

    首先看一下map阶段,即在每个分区内计算加和。初始情况如蓝色方块所示,内容为:

    分区号:里面的内容
    如,0分区内的数据为6和8
    

    当执行seqop时,会说先用初始值0开始遍历累加,原理类似如下:

    rdd.mapPartitions((it:Iterator)=>{
    	var sum = init_value // 默认为0
    	it.foreach(sum + _)
    	sum
    })
    

    因此屏幕上会出现下面的内容,由于分区之间是并行的,所以最后的结果是乱序的:

    seq: 0:6
    seq: 0:1
    seq: 0:3
    seq: 1:9
    seq: 3:10
    seq: 0:2
    seq: 0:5
    seq: 5:7
    seq: 12:12
    seq: 0:4
    seq: 4:11
    seq: 6:8
    

    计算完成后,依次遍历每个分区结果,进行累加:

    comb: 0:10
    comb: 10:13
    comb: 23:2
    comb: 25:24
    comb: 49:15
    comb: 64:14
    

    aggregate的源码也比较简单:

    def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope {
        var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance())
        val cleanSeqOp = sc.clean(seqOp)
        val cleanCombOp = sc.clean(combOp)
        val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
        val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
        sc.runJob(this, aggregatePartition, mergeResult)
        jobResult
      }
    

    treeAggregate

    treeAggregate在aggregate的基础上做了一些优化,因为aggregate是在每个分区计算完成后,把所有的数据拉倒driver端,进行统一的遍历合并,这样如果数据量很大,在driver端可能会OOM。

    因此treeAggregate在中间多加了一层合并。

    先来看看代码,没有任何的变化:

    import org.apache.spark.sql.SparkSession
    
    object TreeAggregateTest {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("tf-idf").getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
    
        val rdd = spark.sparkContext.parallelize(1 to 12).repartition(6)
        rdd.mapPartitionsWithIndex((index:Int,it:Iterator[Int])=>{
          Array(s" $index : ${it.toList.mkString(",")}").toIterator
        }).foreach(println)
    
        val res1 = rdd.treeAggregate(0)(seqOp, combOp)
        println(res1)
      }
    
      def seqOp(s1:Int, s2:Int):Int = {
        println("seq: "+s1+":"+s2)
        s1 + s2
      }
    
      def combOp(c1: Int, c2: Int): Int = {
        println("comb: "+c1+":"+c2)
        c1 + c2
      }
    }
    

    输出的结果则发生了变化,首先分区内的操作不变:

     3 : 3,10
     2 : 2
     0 : 6,8
     1 : 1,9
     4 : 4,11
     5 : 5,7,12
    seq: 0:3
    seq: 0:6
    seq: 3:10
    seq: 6:8
    seq: 0:2
    seq: 0:1
    seq: 1:9
    seq: 0:4
    seq: 4:11
    seq: 0:5
    seq: 5:7
    seq: 12:12
    ...
    

    在合并的时候发生了 变化:

    comb: 10:13
    comb: 23:24
    comb: 14:2
    comb: 16:15
    comb: 47:31
    

    配合下面的流程图,可以更好的理解:

    搭配treeAggregate的源码来看一下:

    def treeAggregate[U: ClassTag](zeroValue: U)(
          seqOp: (U, T) => U,
          combOp: (U, U) => U,
          depth: Int = 2): U = withScope {
        require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
        if (partitions.length == 0) {
          Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
        } else {
    	  // 这里都没什么变化,在分区中遍历数据累加
          val cleanSeqOp = context.clean(seqOp)
          val cleanCombOp = context.clean(combOp)
          val aggregatePartition =
            (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
          var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
    
          // 关键是这下面的内容 !!!!
          // 首先获得当前的分区数
          var numPartitions = partiallyAggregated.partitions.length
          // 计算合适的并行度,我这里相当于6^(1/2),也就是2.4左右,ceill向上取整后变成3.
          // max(3,2)得到最后的结果为3。即每个树的分枝有3个叶子节点
          val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
          
          // 遍历分区,通过对scale取模进行合并计算
          // 这里判断一下,当前的分区数是否还够分。如果少于条件值 scale+(p/scale),就停止分区
          while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
            numPartitions /= scale
            val curNumPartitions = numPartitions
            // 重新定义分区id,并按照分区id重新分区,执行合并计算
            partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
              (i, iter) => iter.map((i % curNumPartitions, _))
            }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
          }
    	  // 最后统计结果
          partiallyAggregated.reduce(cleanCombOp)
        }
      }
    

    spark中的应用

    // matrix求相似度
    def columnSimilarities(threshold: Double): CoordinateMatrix = {
    ...	             columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma)
    }
    // 统计每一个向量的相关数据,里面包含了min max 等等很多信息
    def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = {
      val summary = rows.treeAggregate(new MultivariateOnlineSummarizer)(
        (aggregator, data) => aggregator.add(data),
        (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
      updateNumRows(summary.count)
      summary
    }
    

    了解了treeAggregate之后,后续就可以看matrix的并行求解相似度的源码了!敬请期待吧...

    参考

  • 相关阅读:
    HDOJ骨头的诱惑
    DP Big Event in HDU
    hoj1078
    poj2728
    hoj1195
    poj2739
    poj2726
    海量并发也没那么可怕,运维准点下班全靠它!
    云上安全工作乱如麻,等保2.0来一下
    实践案例丨教你一键构建部署发布前端和Node.js服务
  • 原文地址:https://www.cnblogs.com/xing901022/p/9285898.html
Copyright © 2020-2023  润新知