• Spark SQL 源代码分析之Physical Plan 到 RDD的详细实现


      /** Spark SQL源代码分析系列文章*/

      接上一篇文章Spark SQL Catalyst源代码分析之Physical Plan。本文将介绍Physical Plan的toRDD的详细实现细节:

      我们都知道一段sql,真正的运行是当你调用它的collect()方法才会运行Spark Job,最后计算得到RDD。
      lazy val toRdd: RDD[Row] = executedPlan.execute()

      Spark Plan基本包括4种操作类型,即BasicOperator基本类型,还有就是Join、Aggregate和Sort这样的稍复杂的。

      如图:

      

    一、BasicOperator

    1.1、Project

      Project 的大致含义是:传入一系列表达式Seq[NamedExpression],给定输入的Row。经过Convert(Expression的计算eval)操作。生成一个新的Row。
      Project的实现是调用其child.execute()方法,然后调用mapPartitions对每一个Partition进行操作。
      这个f函数事实上是new了一个MutableProjection,然后循环的对每一个partition进行Convert。
    case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
      override def output = projectList.map(_.toAttribute)
      override def execute() = child.execute().mapPartitions { iter => //对每一个分区进行f映射
        @transient val reusableProjection = new MutableProjection(projectList) 
        iter.map(reusableProjection)
      }
    }
      通过观察MutableProjection的定义,能够发现。就是bind references to a schema 和 eval的过程:
      将一个Row转换为还有一个已经定义好schema column的Row。
      假设输入的Row已经有Schema了,则传入的Seq[Expression]也会bound到当前的Schema。

    case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
      def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
        this(expressions.map(BindReferences.bindReference(_, inputSchema))) //bound schema
    
      private[this] val exprArray = expressions.toArray
      private[this] val mutableRow = new GenericMutableRow(exprArray.size) //新的Row
      def currentValue: Row = mutableRow
      def apply(input: Row): Row = {
        var i = 0
        while (i < exprArray.length) {
          mutableRow(i) = exprArray(i).eval(input)  //依据输入的input,即一个Row,计算生成的Row
          i += 1
        }
        mutableRow //返回新的Row
      }
    }

    1.2、Filter

     Filter的详细实现是传入的condition进行对input row的eval计算。最后返回的是一个Boolean类型,
     假设表达式计算成功。返回true,则这个分区的这条数据就会保存下来,否则会过滤掉。
    case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
      override def output = child.output
    
      override def execute() = child.execute().mapPartitions { iter =>
        iter.filter(condition.eval(_).asInstanceOf[Boolean]) //计算表达式 eval(input row)
      }
    }

    1.3、Sample

      Sample取样操作事实上是调用了child.execute()的结果后,返回的是一个RDD,对这个RDD调用其sample函数,原生方法。
    case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
      extends UnaryNode
    {
      override def output = child.output
    
      // TODO: How to pick seed?
      override def execute() = child.execute().sample(withReplacement, fraction, seed)
    }

    1.4、Union

      Union操作支持多个子查询的Union,所以传入的child是一个Seq[SparkPlan]
      execute()方法的实现是对其全部的children,每一个进行execute()。即select查询的结果集合RDD。

      通过调用SparkContext的union方法。将全部子查询的结果合并起来。
    case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {
      // TODO: attributes output by union should be distinct for nullability purposes
      override def output = children.head.output
      override def execute() = sqlContext.sparkContext.union(children.map(_.execute())) //子查询的结果进行union
    
      override def otherCopyArgs = sqlContext :: Nil
    }

    1.5、Limit

      Limit操作在RDD的原生API里也有。即take().
      可是Limit的实现分2种情况:
      第一种是 limit作为结尾的操作符,即select xxx from yyy limit zzz。 而且是被executeCollect调用,则直接在driver里使用take方法。
      另外一种是 limit不是作为结尾的操作符。即limit后面还有查询,那么就在每一个分区调用limit,最后repartition到一个分区来计算global limit.
    case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
      extends UnaryNode {
      // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
      // partition local limit -> exchange into one partition -> partition local limit again
    
      override def otherCopyArgs = sqlContext :: Nil
    
      override def output = child.output
    
      override def executeCollect() = child.execute().map(_.copy()).take(limit) //直接在driver调用take
    
      override def execute() = {
        val rdd = child.execute().mapPartitions { iter =>
          val mutablePair = new MutablePair[Boolean, Row]()
          iter.take(limit).map(row => mutablePair.update(false, row)) //每一个分区先计算limit
        }
        val part = new HashPartitioner(1)
        val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part) //须要shuffle,来repartition
        shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
        shuffled.mapPartitions(_.take(limit).map(_._2)) //最后单独一个partition来take limit
      }
    }

    1.6、TakeOrdered

      TakeOrdered是经过排序后的limit N,通常是用在sort by 操作符后的limit。
      能够简单理解为TopN操作符。
    case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
                          (@transient sqlContext: SQLContext) extends UnaryNode {
      override def otherCopyArgs = sqlContext :: Nil
    
      override def output = child.output
    
      @transient
      lazy val ordering = new RowOrdering(sortOrder) //这里是通过RowOrdering来实现排序的
    
      override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering)
    
      // TODO: Terminal split should be implemented differently from non-terminal split.
      // TODO: Pick num splits based on |limit|.
      override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
    }

    1.7、Sort

      Sort也是通过RowOrdering这个类来实现排序的,child.execute()对每一个分区进行map,每一个分区依据RowOrdering的order来进行排序,生成一个新的有序集合。

      也是通过调用Spark RDD的sorted方法来实现的。

    case class Sort(
        sortOrder: Seq[SortOrder],
        global: Boolean,
        child: SparkPlan)
      extends UnaryNode {
      override def requiredChildDistribution =
        if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
    
      @transient
      lazy val ordering = new RowOrdering(sortOrder) //排序顺序
    
      override def execute() = attachTree(this, "sort") {
        // TODO: Optimize sorting operation?
        child.execute()
          .mapPartitions(
            iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, //每一个分区调用sorted方法,传入<span style="font-family: Arial, Helvetica, sans-serif;">ordering排序规则,进行排序</span>
            preservesPartitioning = true)
      }
    
      override def output = child.output
    }

    1.8、ExistingRdd

    ExistingRdd是
    object ExistingRdd {
      def convertToCatalyst(a: Any): Any = a match {
        case o: Option[_] => o.orNull
        case s: Seq[Any] => s.map(convertToCatalyst)
        case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
        case other => other
      }
    
      def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
        data.mapPartitions { iterator =>
          if (iterator.isEmpty) {
            Iterator.empty
          } else {
            val bufferedIterator = iterator.buffered
            val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
    
            bufferedIterator.map { r =>
              var i = 0
              while (i < mutableRow.length) {
                mutableRow(i) = convertToCatalyst(r.productElement(i))
                i += 1
              }
    
              mutableRow
            }
          }
        }
      }
    
      def fromProductRdd[A <: Product : TypeTag](productRdd: RDD[A]) = {
        ExistingRdd(ScalaReflection.attributesFor[A], productToRowRdd(productRdd))
      }
    }

    二、 Join Related Operators

      HashJoin:

      在解说Join Related Operator之前。有必要了解一下HashJoin这个位于execution包下的joins.scala文件中的trait。

      Join操作主要包括BroadcastHashJoinLeftSemiJoinHashShuffledHashJoin均实现了HashJoin这个trait.
      主要类图例如以下:
      
      
      HashJoin这个trait的主要成员有:
      buildSide是左连接还是右连接,有一种基准的意思。
      leftKeys是左孩子的expressions, rightKeys是右孩子的expressions。
      left是左孩子物理计划,right是右孩子物理计划。

      buildSideKeyGenerator是一个Projection是依据传入的Row对象来计算buildSide的Expression的。
      streamSideKeyGenerator是一个MutableProjection是依据传入的Row对象来计算streamSide的Expression的。
      这里buildSide假设是left的话,能够理解为buildSide是左表,那么去连接这个左表的右表就是streamSide。

      
      HashJoin关键的操作是joinIterators。简单来说就是join两个表。把每一个表看着Iterators[Row].
      方式:
      1、首先遍历buildSide,计算buildKeys然后利用一个HashMap,形成 (buildKeys, Iterators[Row])的格式。
      2、遍历StreamedSide。计算streamedKey,去HashMap里面去匹配key,来进行join
      3、最后生成一个joinRow,这个将2个row对接。

      见代码凝视:
    trait HashJoin {
      val leftKeys: Seq[Expression]
      val rightKeys: Seq[Expression]
      val buildSide: BuildSide
      val left: SparkPlan
      val right: SparkPlan
    
      lazy val (buildPlan, streamedPlan) = buildSide match {  //模式匹配,将physical plan封装形成Tuple2,假设是buildLeft。那么就是(left,right),否则是(right,left)
        case BuildLeft => (left, right)
        case BuildRight => (right, left)
      }
    
      lazy val (buildKeys, streamedKeys) = buildSide match { //模式匹配,将expression进行封装<span style="font-family: Arial, Helvetica, sans-serif;">Tuple2</span>
    
        case BuildLeft => (leftKeys, rightKeys)
        case BuildRight => (rightKeys, leftKeys)
      }
    
      def output = left.output ++ right.output
    
      @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) //生成buildSideKey来依据Expression来计算Row返回结果
      @transient lazy val streamSideKeyGenerator = //<span style="font-family: Arial, Helvetica, sans-serif;">生成</span><span style="font-family: Arial, Helvetica, sans-serif;">streamSideKeyGenerator</span><span style="font-family: Arial, Helvetica, sans-serif;">来依据Expression来计算Row返回结果</span>
        () => new MutableProjection(streamedKeys, streamedPlan.output)
    
      def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = { //把build表的Iterator[Row]和streamIterator[Row]进行join操作返回Join后的Iterator[Row]
        // TODO: Use Spark's HashMap implementation.
    
        val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]() //匹配主要使用HashMap实现
        var currentRow: Row = null
    
        // Create a mapping of buildKeys -> rows 
        while (buildIter.hasNext) { //眼下仅仅对build Iterator进行迭代,形成rowKey,Rows,相似wordCount,可是这里不是累加Value,而是Row的集合。

    currentRow = buildIter.next() val rowKey = buildSideKeyGenerator(currentRow) //计算rowKey作为HashMap的key if(!rowKey.anyNull) { val existingMatchList = hashTable.get(rowKey) val matchList = if (existingMatchList == null) { val newMatchList = new ArrayBuffer[Row]() hashTable.put(rowKey, newMatchList) //(rowKey, matchedRowList) newMatchList } else { existingMatchList } matchList += currentRow.copy() //返回matchList } } new Iterator[Row] { //最后用streamedRow的Key来匹配buildSide端的HashMap private[this] var currentStreamedRow: Row = _ private[this] var currentHashMatches: ArrayBuffer[Row] = _ private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. private[this] val joinRow = new JoinedRow private[this] val joinKeys = streamSideKeyGenerator() override final def hasNext: Boolean = (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) || (streamIter.hasNext && fetchNext()) override final def next() = { val ret = buildSide match { case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) //右连接的话,streamedRow放左边。匹配到的key的Row放到右表 case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) //左连接的话,相反。

    } currentMatchPosition += 1 ret } /** * Searches the streamed iterator for the next row that has at least one match in hashtable. * * @return true if the search is successful, and false if the streamed iterator runs out of * tuples. */ private final def fetchNext(): Boolean = { currentHashMatches = null currentMatchPosition = -1 while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() if (!joinKeys(currentStreamedRow).anyNull) { currentHashMatches = hashTable.get(joinKeys.currentValue) //streamedRow从buildSide里的HashTable里面匹配rowKey } } if (currentHashMatches == null) { false } else { currentMatchPosition = 0 true } } } } }

    joinRow的实现,实现2个Row对接:
    实际上就是生成一个新的Array,将2个Array合并。

    class JoinedRow extends Row {
      private[this] var row1: Row = _
      private[this] var row2: Row = _
      .........
       def copy() = {
        val totalSize = row1.size + row2.size 
        val copiedValues = new Array[Any](totalSize)
        var i = 0
        while(i < totalSize) {
          copiedValues(i) = apply(i)
          i += 1
        }
        new GenericRow(copiedValues) //返回一个新的合并后的Row
      }

    2.1、LeftSemiJoinHash

     left semi join,不多说了。hive早期版本号里替代IN和EXISTS 的版本号。
     将右表的join keys放到HashSet里。然后遍历左表,查找左表的join key能否匹配。
    case class LeftSemiJoinHash(
        leftKeys: Seq[Expression],
        rightKeys: Seq[Expression],
        left: SparkPlan,
        right: SparkPlan) extends BinaryNode with HashJoin {
    
      val buildSide = BuildRight //buildSide是以右表为基准
    
      override def requiredChildDistribution =
        ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
    
      override def output = left.output
    
      def execute() = {
        buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => //右表的物理计划运行后生成RDD,利用zipPartitions对Partition进行合并。然后用上述方法实现。
          val hashSet = new java.util.HashSet[Row]()
          var currentRow: Row = null
    
          // Create a Hash set of buildKeys
          while (buildIter.hasNext) {
            currentRow = buildIter.next()
            val rowKey = buildSideKeyGenerator(currentRow)
            if(!rowKey.anyNull) {
              val keyExists = hashSet.contains(rowKey)
              if (!keyExists) {
                hashSet.add(rowKey)
              }
            }
          }
    
          val joinKeys = streamSideKeyGenerator()
          streamIter.filter(current => {
            !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
          })
        }
      }
    }

    2.2、BroadcastHashJoin

     名约: 广播HashJoin,呵呵。
      是InnerHashJoin的实现。这里用到了concurrent并发里的future,异步的广播buildPlan的表运行后的的RDD。

      假设接收到了广播后的表,那么就用streamedPlan来匹配这个广播的表。

      实现是RDD的mapPartitions和HashJoin里的joinIterators最后生成join的结果。
    case class BroadcastHashJoin(
         leftKeys: Seq[Expression],
         rightKeys: Seq[Expression],
         buildSide: BuildSide,
         left: SparkPlan,
         right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin {
    
      override def otherCopyArgs = sqlContext :: Nil
    
      override def outputPartitioning: Partitioning = left.outputPartitioning
    
      override def requiredChildDistribution =
        UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
    
      @transient
      lazy val broadcastFuture = future {  //利用SparkContext广播表
        sqlContext.sparkContext.broadcast(buildPlan.executeCollect())
      }
    
      def execute() = {
        val broadcastRelation = Await.result(broadcastFuture, 5.minute)
    
        streamedPlan.execute().mapPartitions { streamedIter =>
          joinIterators(broadcastRelation.value.iterator, streamedIter) //调用joinIterators对每一个分区map
        }
      }
    }

    2.3、ShuffleHashJoin

    ShuffleHashJoin顾名思义就是须要shuffle数据,outputPartitioning是左孩子的的Partitioning。

    会依据这个Partitioning进行shuffle。

    然后利用SparkContext里的zipPartitions方法对每一个分区进行zip。

    这里的requiredChildDistribution。的是ClusteredDistribution,这个会在HashPartitioning里面进行匹配。

    关于这里面的分区这里不赘述,能够去org.apache.spark.sql.catalyst.plans.physical下的partitioning里面去查看。
    case class ShuffledHashJoin(
        leftKeys: Seq[Expression],
        rightKeys: Seq[Expression],
        buildSide: BuildSide,
        left: SparkPlan,
        right: SparkPlan) extends BinaryNode with HashJoin {
    
      override def outputPartitioning: Partitioning = left.outputPartitioning
    
      override def requiredChildDistribution =
        ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
    
      def execute() = {
        buildPlan.execute().zipPartitions(streamedPlan.execute()) {
          (buildIter, streamIter) => joinIterators(buildIter, streamIter)
        }
      }
    }


    未完待续 :)

    原创文章,转载请注明:

    转载自:OopsOutOfMemory盛利的Blog。作者: OopsOutOfMemory

    本文链接地址:http://blog.csdn.net/oopsoom/article/details/38274621

    注:本文基于署名-非商业性使用-禁止演绎 2.5 中国大陆(CC BY-NC-ND 2.5 CN)协议。欢迎转载、转发和评论。可是请保留本文作者署名和文章链接。如若须要用于商业目的或者与授权方面的协商,请联系我。

    image


  • 相关阅读:
    c++ 单步查看汇编代码【转】
    c++ 类内部函数调用虚函数
    grep和sed替换文件中的字符串【转】
    vim 正则替换【转】
    linux 文件编码问题
    shell截取字符串的一些简单方法
    chrome 安装页面编码选择插件
    namespace main
    【转】c++ 多线程
    使用git提交到github,每次都要输入用户名和密码的解决方法
  • 原文地址:https://www.cnblogs.com/cxchanpin/p/6869232.html
Copyright © 2020-2023  润新知