• Spark Streaming: Demystifying the updateStateByKey and mapWithState Source Code


    This article analyzes the source code from two angles:

    1. Demystifying updateStateByKey

    2. Demystifying mapWithState

    Studying Spark is also a good way to learn about the JVM, distributed systems, graph computation, architecture design, and software engineering ideas.

    The dynamic blacklist generation and filtering example uses the updateStateByKey method. This method is not defined in the DStream class itself; it becomes available through an implicit conversion defined in the DStream companion object, as shown in the following code:

    object DStream {
      // `toPairDStreamFunctions` was in SparkContext before 1.3 and users had to
      // `import StreamingContext._` to enable it. Now we move it here to make the compiler find
      // it automatically. However, we still keep the old function in StreamingContext for backward
      // compatibility and forward to the following function directly.
      implicit def toPairDStreamFunctions[K, V](stream: DStream[(K, V)])
          (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null):
        PairDStreamFunctions[K, V] = {
        new PairDStreamFunctions[K, V](stream)
      }
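
    To ground this, here is a minimal usage sketch of updateStateByKey for a running word count; the socket source, port, and checkpoint path are assumptions for illustration, not from the original post. The implicit conversion above is what makes updateStateByKey callable on a DStream of (K, V) pairs:

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val conf = new SparkConf().setAppName("StatefulWordCount").setMaster("local[2]")
    val ssc = new StreamingContext(conf, Seconds(5))
    ssc.checkpoint("/tmp/streaming-checkpoint")   // updateStateByKey requires checkpointing

    val pairs = ssc.socketTextStream("localhost", 9999)
      .flatMap(_.split(" "))
      .map(word => (word, 1))

    // updateFunc: (values that arrived in this batch, previous state) => new state
    val updateFunc = (values: Seq[Int], state: Option[Int]) =>
      Some(values.sum + state.getOrElse(0))

    // Resolved through toPairDStreamFunctions; uses defaultPartitioner under the hood
    val counts = pairs.updateStateByKey[Int](updateFunc)
    counts.print()

    ssc.start()
    ssc.awaitTermination()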

    Following that conversion into PairDStreamFunctions, the method is defined there:

    /**
     * Return a new "state" DStream where the state for each key is updated by applying
     * the given function on the previous state of the key and the new values of each key.
     * Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
     * @param updateFunc State update function. If `this` function returns None, then
     *                   corresponding state key-value pair will be eliminated.
     * @param numPartitions Number of partitions of each RDD in the new DStream.
     * @tparam S State type
     */
    def updateStateByKey[S: ClassTag](
        updateFunc: (Seq[V], Option[S]) => Option[S],
        numPartitions: Int
      ): DStream[(K, S)] = ssc.withScope {
      updateStateByKey(updateFunc, defaultPartitioner(numPartitions))
    }
    Returning along the call chain, the default partitioner is a HashPartitioner. Its main appeal is efficiency — which is why it was the choice before Spark 1.2: no sorting is needed — and its numPartitions argument sets the parallelism:
    private[streaming] def defaultPartitioner(numPartitions: Int = self.ssc.sc.defaultParallelism) = {
      new HashPartitioner(numPartitions)
    }
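
    HashPartitioner itself does nothing more than map a key's hashCode onto a partition index, which is why it is cheap: no sorting or sampling is involved. A simplified sketch of the idea (illustrative only, not Spark's actual class):

    // Simplified hash partitioning: a key lands in (hashCode mod numPartitions),
    // kept non-negative; there is no ordering of keys at all.
    class SimpleHashPartitioner(val numPartitions: Int) {
      def getPartition(key: Any): Int = key match {
        case null => 0
        case _    => ((key.hashCode % numPartitions) + numPartitions) % numPartitions
      }
    }
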
    /**
     * Return a new "state" DStream where the state for each key is updated by applying
     * the given function on the previous state of the key and the new values of each key.
     * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
     * @param updateFunc State update function. Note, that this function may generate a different
     * tuple with a different key than the input key. Therefore keys may be removed
     * or added in this way. It is up to the developer to decide whether to
     * remember the partitioner despite the key being changed.
     * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
     * DStream
     * @param rememberPartitioner Whether to remember the paritioner object in the generated RDDs.
     * @tparam S State type
     */
    def updateStateByKey[S: ClassTag](
        updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
        partitioner: Partitioner,
        rememberPartitioner: Boolean
      ): DStream[(K, S)] = ssc.withScope {
       new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
    }
    Following into StateDStream, which extends DStream — continually operating on state produces one StateDStream object after another:
    private[streaming]
    class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
        parent: DStream[(K, V)],
        updateFunc: (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)],
        partitioner: Partitioner,
        preservePartitioning: Boolean,
        initialRDD : Option[RDD[(K, S)]]
      ) extends DStream[(K, S)](parent.ssc) {
      super.persist(StorageLevel.MEMORY_ONLY_SER)

    A key piece of code is the compute method:

    override def compute(validTime: Time): Option[RDD[(K, S)]] = {
      // Try to get the previous state RDD
      getOrCompute(validTime - slideDuration) match {
        case Some(prevStateRDD) => {    // If previous state RDD exists
          // Try to get the parent RDD
          parent.getOrCompute(validTime) match {
            case Some(parentRDD) => {   // If parent RDD exists, then compute as usual
              computeUsingPreviousRDD (parentRDD, prevStateRDD)
            }
            case None => {    // If parent RDD does not exist
              // Re-apply the update function to the old state RDD
              val updateFuncLocal = updateFunc
              val finalFunc = (iterator: Iterator[(K, S)]) => {
                val i = iterator.map(t => (t._1, Seq[V](), Option(t._2)))
                updateFuncLocal(i)
              }
          // From an efficiency standpoint: just re-map the previous state RDD's partitions
              val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning)
              Some(stateRDD)
            }
          }
        }
    From this code: the update function is passed in and the heavy lifting is done by cogroup, which groups values by key — every key in the state is scanned and then aggregated. The advantage is that everything stays a plain RDD computation; the disadvantage is performance: cogroup scans all of the data, so as the state grows over time each batch gets slower, because the new batch RDD is scanned and merged against the entire accumulated state RDD. The key code is shown below:

    private [this] def computeUsingPreviousRDD (
      parentRDD : RDD[(K, V)], prevStateRDD : RDD[(K, S)]) = {
      // Define the function for the mapPartition operation on cogrouped RDD;
      // first map the cogrouped tuple to tuples of required type,
      // and then apply the update function
      val updateFuncLocal = updateFunc
      val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
        val i = iterator.map(t => {
          val itr = t._2._2.iterator
          val headOption = if (itr.hasNext) Some(itr.next()) else None
          (t._1, t._2._1.toSeq, headOption)
        })
        updateFuncLocal(i)
      }
      val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner)
      val stateRDD = cogroupedRDD.mapPartitions(finalFunc, preservePartitioning)
      Some(stateRDD)
    }
    /**
     * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
     * list of values for that key in `this` as well as `other`.
     */
    def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner)
        : RDD[(K, (Iterable[V], Iterable[W]))] = self.withScope {
      if (partitioner.isInstanceOf[HashPartitioner] && keyClass.isArray) {
        throw new SparkException("Default partitioner cannot partition array keys.")
      }
      val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)
      cg.mapValues { case Array(vs, w1s) =>
        (vs.asInstanceOf[Iterable[V]], w1s.asInstanceOf[Iterable[W]])
      }
    }
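
    To make the cost concrete, here is a small illustrative sketch with plain RDDs (the data and the sc handle are assumptions, not from the original post): cogroup touches every key in the accumulated state even when only one key received new data.

    val prevState = sc.parallelize(Seq(("a", 3), ("b", 5)))   // the accumulated state keys
    val batch     = sc.parallelize(Seq(("a", 1)))             // only "a" arrived this batch
    val cogrouped = batch.cogroup(prevState)                  // "b" is still scanned and shuffled
    cogrouped.collect().foreach(println)
    // prints something like:
    // (a,(CompactBuffer(1),CompactBuffer(3)))
    // (b,(CompactBuffer(),CompactBuffer(5)))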

    Now let's demystify mapWithState

    mapWithState also returns a DStream. Maintaining and updating historical state is done per key; the state is effectively an in-memory data table — a table of key, value and state that is maintained rather than recreated, and that holds all historical state. For each key, the value is updated on top of the existing state, as in a word count that keeps accumulating:

    /**
     * :: Experimental ::
     * Return a [[MapWithStateDStream]] by applying a function to every key-value element of
     * `this` stream, while maintaining some state data for each unique key. The mapping function
     * and other specification (e.g. partitioners, timeouts, initial state data, etc.) of this
     * transformation can be specified using [[StateSpec]] class. The state data is accessible in
     * as a parameter of type [[State]] in the mapping function.
     *
     * Example of using `mapWithState`:
     * {{{
     *    // A mapping function that maintains an integer state and return a String
     *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
     *      // Use state.exists(), state.get(), state.update() and state.remove()
     *      // to manage state, and return the necessary string
     *    }
     *
     *    val spec = StateSpec.function(mappingFunction).numPartitions(10)
     *
     *    val mapWithStateDStream = keyValueDStream.mapWithState[StateType, MappedType](spec)
     * }}}
     *
     * @param spec          Specification of this transformation
     * @tparam StateType    Class type of the state data
     * @tparam MappedType   Class type of the mapped data
     */
    @Experimental
    def mapWithState[StateType: ClassTag, MappedType: ClassTag](
        spec: StateSpec[K, V, StateType, MappedType]
      ): MapWithStateDStream[K, V, StateType, MappedType] = {
      new MapWithStateDStreamImpl[K, V, StateType, MappedType](
        self,
        spec.asInstanceOf[StateSpecImpl[K, V, StateType, MappedType]]
      )
    }
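
    A minimal usage sketch of mapWithState (reusing the assumed pairs DStream and checkpoint setup from the earlier sketch): only keys that appear in the current batch have their state entry touched.

    import org.apache.spark.streaming.{State, StateSpec}

    // Mapping function: (key, value in this batch, per-key state) => record emitted downstream
    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
      state.update(sum)   // update only this key's entry in the state table
      (word, sum)         // the mapped record for this batch
    }

    val stateCounts = pairs.mapWithState(StateSpec.function(mappingFunc))
    stateCounts.print()
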
    Every entry in this in-memory state table carries a set of flags: defined, timingOut, updated, removed:
    /**
     * :: Experimental ::
     * Abstract class for getting and updating the state in mapping function used in the `mapWithState`
     * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala)
     * or a [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java).
     *
     * Scala example of using `State`:
     * {{{
     *    // A mapping function that maintains an integer state and returns a String
     *    def mappingFunction(key: String, value: Option[Int], state: State[Int]): Option[String] = {
     *      // Check if state exists
     *      if (state.exists) {
     *        val existingState = state.get  // Get the existing state
     *        val shouldRemove = ...         // Decide whether to remove the state
     *        if (shouldRemove) {
     *          state.remove()     // Remove the state
     *        } else {
     *          val newState = ...
     *          state.update(newState)    // Set the new state
     *        }
     *      } else {
     *        val initialState = ...
     *        state.update(initialState)  // Set the initial state
     *      }
     *      ... // return something
     *    }
     * }}}
     */
    /** Internal implementation of the [[State]] interface */
    private[streaming] class StateImpl[S] extends State[S] {
      private var state: S = null.asInstanceOf[S]
      private var defined: Boolean = false
      private var timingOut: Boolean = false
      private var updated: Boolean = false
      private var removed: Boolean = false

      // ========= Public API =========
      override def exists(): Boolean = {
        defined
      }

      override def get(): S = {
        if (defined) {
          state
        } else {
          throw new NoSuchElementException("State is not set")
        }
      }

      override def update(newState: S): Unit = {
        require(!removed, "Cannot update the state after it has been removed")
        require(!timingOut, "Cannot update the state that is timing out")
        state = newState
        defined = true
        updated = true
      }
    In the code below, `function` is the mapping function passed in from the user:
    /** Internal implementation of [[org.apache.spark.streaming.StateSpec]] interface. */
    private[streaming]
    case class StateSpecImpl[K, V, S, T](
        function: (Time, K, Option[V], State[S]) => Option[T]) extends StateSpec[K, V, S, T] {

    /** Internal implementation of the [[MapWithStateDStream]] */
    private[streaming] class MapWithStateDStreamImpl[
        KeyType: ClassTag, ValueType: ClassTag, StateType: ClassTag, MappedType: ClassTag](
        dataStream: DStream[(KeyType, ValueType)],
        spec: StateSpecImpl[KeyType, ValueType, StateType, MappedType])
      extends MapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream.context) {

      private val internalStream =
        new InternalMapWithStateDStream[KeyType, ValueType, StateType, MappedType](dataStream, spec)

      override def slideDuration: Duration = internalStream.slideDuration

      override def dependencies: List[DStream[_]] = List(internalStream)

      override def compute(validTime: Time): Option[RDD[MappedType]] = {
        internalStream.getOrCompute(validTime).map { _.flatMap[MappedType] { _.mappedData } }
      }

    Updates are applied against historical data held in an in-memory data structure: the existing structure is updated in place rather than a new one being created on top of it:

    /**
     * A DStream that allows per-key state to be maintained, and arbitrary records to be generated
     * based on updates to the state. This is the main DStream that implements the `mapWithState`
     * operation on DStreams.
     *
     * @param parent Parent (key, value) stream that is the source
     * @param spec Specifications of the mapWithState operation
     * @tparam K   Key type
     * @tparam V   Value type
     * @tparam S   Type of the state maintained
     * @tparam E   Type of the mapped data
     */
    private[streaming]
    class InternalMapWithStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
        parent: DStream[(K, V)], spec: StateSpecImpl[K, V, S, E])
      extends DStream[MapWithStateRDDRecord[K, S, E]](parent.context) {
      persist(StorageLevel.MEMORY_ONLY)

    For each batch interval a new MapWithStateRDD is created — this is where the whole story begins. The key line is:

    Some(new MapWithStateRDD(
      prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))

    The full compute method:

      /** Method that generates a RDD for the given time */
      override def compute(validTime: Time): Option[RDD[MapWithStateRDDRecord[K, S, E]]] = {
        // Get the previous state or create a new empty state RDD
        val prevStateRDD = getOrCompute(validTime - slideDuration) match {
          case Some(rdd) =>
            if (rdd.partitioner != Some(partitioner)) {
              // If the RDD is not partitioned the right way, let us repartition it using the
              // partition index as the key. This is to ensure that state RDD is always partitioned
              // before creating another state RDD using it
              MapWithStateRDD.createFromRDD[K, V, S, E](
                rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
            } else {
              rdd
            }
          case None =>
            MapWithStateRDD.createFromPairRDD[K, V, S, E](
              spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
              partitioner,
              validTime
            )
        }
        // Compute the new state RDD with previous state RDD and partitioned data RDD
        // Even if there is no data RDD, use an empty one to create a new state RDD
        val dataRDD = parent.getOrCompute(validTime).getOrElse {
          context.sparkContext.emptyRDD[(K, V)]
        }
        val partitionedDataRDD = dataRDD.partitionBy(partitioner)
        val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
          (validTime - interval).milliseconds
        }
        Some(new MapWithStateRDD(
          prevStateRDD, partitionedDataRDD, mappingFunction, validTime, timeoutThresholdTime))
      }
    }

    Now look at MapWithStateRDD:

    /**
     * RDD storing the keyed states of `mapWithState` operation and corresponding mapped data.
     * Each partition of this RDD has a single record of type [[MapWithStateRDDRecord]]. This contains a
     * [[StateMap]] (containing the keyed-states) and the sequence of records returned by the mapping
     * function of `mapWithState`.
     * @param prevStateRDD The previous MapWithStateRDD on whose StateMap data `this` RDD
     *                     will be created
     * @param partitionedDataRDD The partitioned data RDD which is used update the previous StateMaps
     *                           in the `prevStateRDD` to create `this` RDD
     * @param mappingFunction  The function that will be used to update state and return new data
     * @param batchTime        The time of the batch to which this RDD belongs to. Use to update
     * @param timeoutThresholdTime The time to indicate which keys are timeout
     */
    private[streaming] class MapWithStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
        private var prevStateRDD: RDD[MapWithStateRDDRecord[K, S, E]],
        private var partitionedDataRDD: RDD[(K, V)],
        mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
        batchTime: Time,
        timeoutThresholdTime: Option[Long]
      ) extends RDD[MapWithStateRDDRecord[K, S, E]](
        partitionedDataRDD.sparkContext,
        List(
          new OneToOneDependency[MapWithStateRDDRecord[K, S, E]](prevStateRDD),
          new OneToOneDependency(partitionedDataRDD))
      ) {

    Each partition is represented by a single MapWithStateRDDRecord, which holds a stateMap data structure. Now look at the key method of this class, compute:

    override def compute(
      partition: Partition, context: TaskContext): Iterator[MapWithStateRDDRecord[K, S, E]] = {
      val stateRDDPartition = partition.asInstanceOf[MapWithStateRDDPartition]
      val prevStateRDDIterator = prevStateRDD.iterator(
        stateRDDPartition.previousSessionRDDPartition, context)
      val dataIterator = partitionedDataRDD.iterator(
        stateRDDPartition.partitionedDataRDDPartition, context)
      val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
      val newRecord = MapWithStateRDDRecord.updateRecordWithData(
        prevRecord,
        dataIterator,
        mappingFunction,
        batchTime,
        timeoutThresholdTime,
        removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
      )
      Iterator(newRecord)
    }

     private[streaming] object MapWithStateRDDRecord {

      def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
        prevRecord: Option[MapWithStateRDDRecord[K, S, E]],
        dataIterator: Iterator[(K, V)],
        mappingFunction: (Time, K, Option[V], State[S]) => Option[E],
        batchTime: Time,
        timeoutThresholdTime: Option[Long],
        removeTimedoutData: Boolean
      ): MapWithStateRDDRecord[K, S, E] = {
        // Create a new state map by cloning the previous one (if it exists) or by creating an empty one
        val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }
        val mappedData = new ArrayBuffer[E]
        val wrappedState = new StateImpl[S]()
        // Call the mapping function on each record in the data iterator, and accordingly
        // update the states touched, and collect the data returned by the mapping function
        dataIterator.foreach { case (key, value) =>
          wrappedState.wrap(newStateMap.get(key))
          val returned = mappingFunction(batchTime, key, Some(value), wrappedState)
          if (wrappedState.isRemoved) {
            newStateMap.remove(key)
          } else if (wrappedState.isUpdated
              || (wrappedState.exists && timeoutThresholdTime.isDefined)) {
            newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
          }
          mappedData ++= returned
        }
        // Get the timed out state records, call the mapping function on each and collect the
        // data returned
        if (removeTimedoutData && timeoutThresholdTime.isDefined) {
          newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
            wrappedState.wrapTimingOutState(state)
            val returned = mappingFunction(batchTime, key, None, wrappedState)
            mappedData ++= returned
            newStateMap.remove(key)
          }
        }

        MapWithStateRDDRecord(newStateMap, mappedData)
      }
    wrappedState is reassigned over and over as keys are processed, and mappedData holds the values that are finally returned. The computation only walks the current batch's data and updates the newStateMap structure, which preserves the historical state; the history itself is never recomputed or traversed — only updates and inserts happen. Each record represents a partition, and the MapWithStateRDDRecord reference held by the RDD does not change: the DStream operates on RDDs while the data structure inside the record is mutated, which is how an immutable RDD can process constantly changing state.

    Spark Streaming Release Notes, Part 14

    Sina Weibo: http://weibo.com/ilovepains

    WeChat official account: DT_Spark

    Blog: http://blog.sina.com.cn/ilovepains

    Phone: 18610086859

    QQ: 1740415547

    Email: 18610086859@vip.126.com

     
  • Original post: https://www.cnblogs.com/sparkbigdata/p/5544447.html