• A Closer Look at Apache Spark 1.0.0 (Part 11): The Shuffle Process


    1. How Shuffle Arises

    A shuffle dependency is the basis on which a job is divided into stages, and it determines whether a stage runs ShuffleMapTasks or ResultTasks, as the source comment below describes:

    * A Spark job consists of one or more stages. The very last stage in a job consists of multiple
    * ResultTasks, while earlier stages consist of ShuffleMapTasks. A ResultTask executes the task
    * and sends the task output back to the driver application. A ShuffleMapTask executes the task
    * and divides the task output to multiple buckets (based on the task's partitioner).

    Shuffle is an essential step in the MapReduce model: it is the bridge between the map side and the reduce side. In Spark, a shuffle can only arise from operations on pair RDDs whose elements are (k, v) pairs; other RDDs never produce a shuffle. When the map output is to be consumed by the reducers, it has to be hashed by key and distributed to every reducer, and that process is the shuffle. Because it involves disk reads and writes as well as network transfer, shuffle performance directly determines the efficiency of the whole program, which is why the shuffle is the key to tuning Spark, and MapReduce-style frameworks in general.
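
    For instance, the word-count style job sketched below introduces exactly one shuffle at reduceByKey; the map side hashes each record by key into buckets and the reduce side fetches those buckets. This is only an illustration: the SparkContext sc and the input path are assumed, not taken from this article.

    // Minimal sketch of a shuffle-producing job; `sc` and the input path are assumed.
    import org.apache.spark.SparkContext._   // Spark 1.0-era implicits for pair RDD operations

    val counts = sc.textFile("hdfs:///tmp/words.txt")   // illustrative path
      .flatMap(_.split("\\s+"))
      .map(word => (word, 1))      // a pair RDD of (k, v), the only kind of RDD that can shuffle
      .reduceByKey(_ + _)          // introduces the ShuffleDependency (map output hashed by key)
      .collect()                   // runs in the final ResultStage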

    2. Shuffle Write

    As noted at the end of the earlier post on task execution, ShuffleMapTask and ResultTask implement runTask differently; the main difference is whether the intermediate results are written out. The following walks through ShuffleMapTask.runTask in several parts.

    (I) Defining the variables

    Four variables are defined first: numOutputSplits, blockManager, shuffleBlockManager and shuffle. numOutputSplits is the number of partitions of the dependency's partitioner; blockManager is obtained from SparkEnv; shuffleBlockManager is taken from the blockManager; and shuffle is declared as a ShuffleWriterGroup, initially null.

    val numOutputSplits = dep.partitioner.numPartitions
    
    val blockManager = SparkEnv.get.blockManager
    val shuffleBlockManager = blockManager.shuffleBlockManager
    var shuffle: ShuffleWriterGroup = null

    The class declaration of ShuffleBlockManager is shown below. As its comment explains, this class assigns disk-based block writers to shuffle tasks. Each shuffle task gets one file per reducer, and this set of files is called a ShuffleFileGroup. To reduce the number of physical shuffle files produced, multiple shuffle blocks are aggregated into the same file; as soon as a task finishes writing its shuffle files, it releases them for another task to use.

    A shuffle file is uniquely identified by the 3-tuple (shuffleId, bucketId, fileId). Each shuffle file is then mapped to a FileSegment, itself a 3-tuple (file, offset, length), which specifies where in a given file the actual block data is located. The shuffle file metadata is stored in a space-efficient manner: every ShuffleFileGroup maintains a list of offsets for each block stored in each of its files, and to find the location of a shuffle block we search the ShuffleFileGroups associated with the block's reducer.

    /**
     * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file
     * per reducer (this set of files is called a ShuffleFileGroup).
     *
     * As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
     * blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
     * per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle
     * files, it releases them for another task.
     * Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
     *   - shuffleId: The unique id given to the entire shuffle stage.
     *   - bucketId: The id of the output partition (i.e., reducer id)
     *   - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
     *       time owns a particular fileId, and this id is returned to a pool when the task finishes.
     * Each shuffle file is then mapped to a FileSegment, which is a 3-tuple (file, offset, length)
     * that specifies where in a given file the actual block data is located.
     *
     * Shuffle file metadata is stored in a space-efficient manner. Rather than simply mapping
     * ShuffleBlockIds directly to FileSegments, each ShuffleFileGroup maintains a list of offsets for
     * each block stored in each file. In order to find the location of a shuffle block, we search the
     * files within a ShuffleFileGroups associated with the block's reducer.
     */
    private[spark]
    class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
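
    To make the consolidation metadata concrete, here is a simplified, self-contained model of how a file group can recover the FileSegment of one map task's output from per-file offset lists. The names SimpleFileGroup, segmentFor and mapIndexInFile are mine for illustration and do not appear in Spark; the real ShuffleFileGroup is more involved.

    import java.io.File

    // Illustrative only: a toy version of the (file, offset, length) lookup described above.
    case class FileSegment(file: File, offset: Long, length: Long)

    class SimpleFileGroup(
        filesPerBucket: Array[File],            // one consolidated file per reducer (bucket)
        blockOffsets: Array[Array[Long]]) {     // per bucket: cumulative offsets of appended
                                                // blocks, ending with the file's total length
      def segmentFor(bucketId: Int, mapIndexInFile: Int): FileSegment = {
        val offsets = blockOffsets(bucketId)
        val start   = offsets(mapIndexInFile)
        val end     = offsets(mapIndexInFile + 1)
        FileSegment(filesPerBucket(bucketId), start, end - start)
      }
    }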

    ShuffleWriterGroup is declared as follows; it defines a group of writers for a ShuffleMapTask, one writer per reducer:

    /** A group of writers for a ShuffleMapTask, one writer per reducer. */
    private[spark] trait ShuffleWriterGroup {

    (II) Obtaining the shuffle writers

    All the block writers for the shuffle blocks are obtained next: first the serializer is resolved, then shuffleBlockManager.forMapTask is called with (shuffleId = dep.shuffleId, mapId = partitionId, numBuckets = numOutputSplits) to get the shuffle writer group:

    // Obtain all the block writers for shuffle blocks.
    val ser = Serializer.getSerializer(dep.serializer)
    shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)

    Looking at ShuffleBlockManager.forMapTask, writers is in fact an array of BlockObjectWriter:

    def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
        new ShuffleWriterGroup {
          shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
          private val shuffleState = shuffleStates(shuffleId)
          private var fileGroup: ShuffleFileGroup = null
    
          val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
            fileGroup = getUnusedFileGroup()
            Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
              val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
              blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
            }
          } else {
            Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
              val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
              val blockFile = blockManager.diskBlockManager.getFile(blockId)
              // Because of previous failures, the shuffle file may already exist on this machine.
              // If so, remove it.
              if (blockFile.exists) {
                if (blockFile.delete()) {
                  logInfo(s"Removed existing shuffle file $blockFile")
                } else {
                  logWarning(s"Failed to remove existing shuffle file $blockFile")
                }
              }
              blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
            }
          }

    Regarding the consolidateShuffleFiles option:

    // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
    // TODO: Remove this once the shuffle file consolidation feature is stable.
    val consolidateShuffleFiles =
      conf.getBoolean("spark.shuffle.consolidateFiles", false)
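
    So consolidation has to be enabled explicitly. A minimal configuration sketch follows; the app name is illustrative, while the key and its default are exactly those quoted above.

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .setAppName("shuffle-consolidation-demo")        // illustrative
      .set("spark.shuffle.consolidateFiles", "true")   // defaults to "false" as shown above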

    If the option is enabled, an unused file group is obtained first: an existing fileGroup is returned if one is available, otherwise a new one is created:

    private def getUnusedFileGroup(): ShuffleFileGroup = {
          val fileGroup = shuffleState.unusedFileGroups.poll()
          if (fileGroup != null) fileGroup else newFileGroup()
    }

    For every 3-tuple (shuffleId, mapId, bucketId) a blockId is created. Within a fileGroup each bucketId maps to a single file, so the map tasks that reuse the same fileGroup append the data destined for the same reducer to the same file; the number of files per file group therefore equals the number of reducers. Indexing fileGroup actually invokes its apply method, which returns the file for the given bucketId:

    def apply(bucketId: Int) = files(bucketId)

    If the option is disabled, every shuffle block gets its own file. The same 3-tuple (shuffleId, mapId, bucketId) determines a blockId, which is passed to blockManager.diskBlockManager.getFile to obtain a blockFile; a file named after the blockId is created in a local directory on disk, so the number of files produced is mappers × reducers. getFile delegates to the overload that takes blockId.name as its argument:

    def getFile(blockId: BlockId): File = getFile(blockId.name)

    The getFile that is ultimately called is:

    def getFile(filename: String): File = {
        // Figure out which local directory it hashes to, and which subdirectory in that
        val hash = Utils.nonNegativeHash(filename)
        val dirId = hash % localDirs.length
        val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
    
        // Create the subdirectory if it doesn't already exist
        var subDir = subDirs(dirId)(subDirId)
        if (subDir == null) {
          subDir = subDirs(dirId).synchronized {
            val old = subDirs(dirId)(subDirId)
            if (old != null) {
              old
            } else {
              val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
              newDir.mkdir()
              subDirs(dirId)(subDirId) = newDir
              newDir
            }
          }
        }

        new File(subDir, filename)
    }

    ShuffleBlockId is a case class whose name defines the file name that the shuffle writer writes to:

    case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
      extends BlockId {
      def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
    }
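
    For example, the block written by map task 5 for reducer 3 of shuffle 0 is named as follows; the ids are made up, the expression simply applies the case class above:

    val id = ShuffleBlockId(shuffleId = 0, mapId = 5, reduceId = 3)
    id.name   // "shuffle_0_5_3", also used as the on-disk file name when consolidation is off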

    Finally, note blockManager.getDiskWriter: its last argument, bufferSize, defaults to 100 KB, which directly affects how much memory the shuffle write path occupies.

    private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
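
    A back-of-the-envelope estimate shows why this matters: every running ShuffleMapTask keeps one open writer per reducer, each with its own buffer. The figures below are illustrative, not measured.

    val bufferKb        = 100      // spark.shuffle.file.buffer.kb default
    val numReducers     = 200      // numOutputSplits of the shuffle (illustrative)
    val concurrentTasks = 16       // ShuffleMapTasks running at once on one executor (illustrative)

    // one open disk writer per reducer per running task
    val bufferBytes = concurrentTasks.toLong * numReducers * bufferKb * 1024
    // = 16 * 200 * 100 KB = 320,000 KB, roughly 312 MB of write buffers on that executor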

    Once the files are created and the disk writers obtained, all of the shuffle's intermediate results end up on disk.

    (III) Writing to the buckets

    The task iterates over its split of the RDD, casts each element to a (K, V) pair, computes the bucketId with the dependency's partitioner, and writes the pair into its bucket through the writer for that bucketId:

    // Write the map output to its associated buckets.
    for (elem <- rdd.iterator(split, context)) {
      val pair = elem.asInstanceOf[Product2[Any, Any]]
      val bucketId = dep.partitioner.getPartition(pair._1)
      shuffle.writers(bucketId).write(pair)
    }

    Key-value pairs are written to the disk files one by one, so there is no need to hold the whole output in memory and flush it to disk in a single pass.
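
    To make the bucketId computation concrete, the sketch below mirrors what a hash partitioner's getPartition does; the standalone helper getBucketId is illustrative, not the Spark class.

    def getBucketId(key: Any, numPartitions: Int): Int =
      if (key == null) 0
      else {
        val rawMod = key.hashCode % numPartitions
        if (rawMod < 0) rawMod + numPartitions else rawMod   // keep the bucket id non-negative
      }

    getBucketId("spark", 8)   // always the same bucket in [0, 8) for the same key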

    write is declared in BlockObjectWriter:

    /**
       * Writes an object.
       */
      def write(value: Any)

    and implemented by DiskBlockObjectWriter.write:

    override def write(value: Any) {
        if (!initialized) {
          open()
        }
        objOut.writeObject(value)
      }

    (IV) Committing the writes

    Note that the size of the data written to each bucket is reported as a single byte (an Array[Byte] with one entry per bucket), which is what the compressSize method is for:

    // Commit the writes. Get the size of each bucket block (total block size).
    var totalBytes = 0L
    var totalTime = 0L
    val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
      writer.commit()
      writer.close()
      val size = writer.fileSegment().length
      totalBytes += size
      totalTime += writer.timeWriting()
      MapOutputTracker.compressSize(size)
    }

    compressSize encodes a size as an exponent with base 1.1: an 8-bit value n stands for roughly 1.1^n bytes, so the 2^8 possible values cover sizes of up to about 35 GB (1.1^255) with at most 10% error, which is a neat trick.

    /**
       * Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
       * We do this by encoding the log base 1.1 of the size as an integer, which can support
       * sizes up to 35 GB with at most 10% error.
       */
      def compressSize(size: Long): Byte = {
        if (size == 0) {
          0
        } else if (size <= 1L) {
          1
        } else {
          math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
        }
      }
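
    A standalone sketch of the encode/decode pair makes the error bound visible; the decoder mirrors MapOutputTracker.decompressSize, quoted later in the read path, and the sample size is arbitrary.

    val LOG_BASE = 1.1

    def compress(size: Long): Byte =
      if (size == 0) 0
      else if (size <= 1L) 1
      else math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte

    def decompress(compressed: Byte): Long =
      if (compressed == 0) 0 else math.pow(LOG_BASE, compressed & 0xFF).toLong

    val size   = 123456789L                    // arbitrary example (~118 MB)
    val approx = decompress(compress(size))    // rounded up to the next power of 1.1
    // approx is within about +10% of size; 1.1^255 is roughly 35 GB, the largest reportable size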

    (V) Updating the metrics

    // Update shuffle metrics.
    val shuffleMetrics = new ShuffleWriteMetrics
    shuffleMetrics.shuffleBytesWritten = totalBytes
    shuffleMetrics.shuffleWriteTime = totalTime
    metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
    
    success = true
    new MapStatus(blockManager.blockManagerId, compressedSizes)

    (VI) Exception handling

        catch { case e: Exception =>
          // If there is an exception from running the task, revert the partial writes
          // and throw the exception upstream to Spark.
          if (shuffle != null && shuffle.writers != null) {
            for (writer <- shuffle.writers) {
              writer.revertPartialWrites()
              writer.close()
            }
          }
          throw e
        } finally {
          // Release the writers back to the shuffle block manager.
          if (shuffle != null && shuffle.writers != null) {
            try {
              shuffle.releaseWriters(success)
            } catch {
              case e: Exception => logError("Failed to release shuffle writers", e)
            }
          }

    (VII) Completion callbacks

    // Execute the callbacks on task completion.
    context.executeOnCompleteCallbacks()

    3. Shuffle Read

    (I) Recording the MapOutputs

    As mentioned in the earlier post on returning results, when a ShuffleMapTask finishes, handleTaskCompletion is called to handle the follow-up work:

    /**
     * Responds to a task finishing. This is called inside the event loop so it assumes that it can
     * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
     */
    private[scheduler] def handleTaskCompletion(event: CompletionEvent) {

    Inside handleTaskCompletion there is a dedicated case for the successful completion of a ShuffleMapTask:

              case smt: ShuffleMapTask =>
                val status = event.result.asInstanceOf[MapStatus]
                val execId = status.location.executorId
                logDebug("ShuffleMapTask finished on " + execId)
                if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
                  logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
                } else {
                  stage.addOutputLoc(smt.partitionId, status)
                }
                if (runningStages.contains(stage) && pendingTasks(stage).isEmpty) {
                  markStageAsFinished(stage)
                  logInfo("looking for newly runnable stages")
                  logInfo("running: " + runningStages)
                  logInfo("waiting: " + waitingStages)
                  logInfo("failed: " + failedStages)
                  if (stage.shuffleDep.isDefined) {
                    // We supply true to increment the epoch number here in case this is a
                    // recomputation of the map outputs. In that case, some nodes may have cached
                    // locations with holes (from when we detected the error) and will need the
                    // epoch incremented to refetch them.
                    // TODO: Only increment the epoch number if this is not the first time
                    //       we registered these map outputs.
                    mapOutputTracker.registerMapOutputs(
                      stage.shuffleDep.get.shuffleId,
                      stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
                      changeEpoch = true)
                  }
                  clearCacheLocs()
                  if (stage.outputLocs.exists(_ == Nil)) {
                    // Some tasks had failed; let's resubmit this stage
                    // TODO: Lower-level scheduler should also deal with this
                    logInfo("Resubmitting " + stage + " (" + stage.name +
                      ") because some of its tasks had failed: " +
                      stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", "))
                    submitStage(stage)
                  } else {
                    val newlyRunnable = new ArrayBuffer[Stage]
                    for (stage <- waitingStages) {
                      logInfo("Missing parents for " + stage + ": " + getMissingParentStages(stage))
                    }
                    for (stage <- waitingStages if getMissingParentStages(stage) == Nil) {
                      newlyRunnable += stage
                    }
                    waitingStages --= newlyRunnable
                    runningStages ++= newlyRunnable
                    for {
                      stage <- newlyRunnable.sortBy(_.id)
                      jobId <- activeJobForStage(stage)
                    } {
                      logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
                      submitMissingTasks(stage, jobId)
                    }
                  }
                }
              }

    After the ShuffleMapTask completes successfully, stage.addOutputLoc is called:

    stage.addOutputLoc(smt.partitionId, status)

    which adds the MapStatus returned by the map task to the stage's outputLocs:

    def addOutputLoc(partition: Int, status: MapStatus) {
        val prevList = outputLocs(partition)
        outputLocs(partition) = status :: prevList
        if (prevList == Nil) {
          numAvailableOutputs += 1
        }
      }

    outputLocs is an array with one List[MapStatus] per partition:

    val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)

    MapStatus records the information about the output that must be passed on to the corresponding reduce tasks: the BlockManagerId of the node the task ran on, plus the (compressed) size of the output produced for each reducer:

    /**
     * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
     * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
     * The map output sizes are compressed using MapOutputTracker.compressSize.
     */
    private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte])
      extends Externalizable {

    The condition stage.shuffleDep.isDefined refers to the field below; it is defined only for map stages, so the registration step runs for a map stage and is skipped for a result (reduce) stage:

    val shuffleDep: Option[ShuffleDependency[_,_]],  // Output shuffle if stage is a map stage

    Once all of the stage's shuffle tasks have finished, registerMapOutputs is called to register the stage's shuffleId together with all of the output locations in the mapOutputTracker:

    mapOutputTracker.registerMapOutputs(
                      stage.shuffleDep.get.shuffleId,
                      stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray,
                      changeEpoch = true)

    MapOutputTrackerMaster.registerMapOutputs is defined as follows:

    /** Register multiple map output information for the given shuffle */
      def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
        mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
        if (changeEpoch) {
          incrementEpoch()
        }
      }

    (II) Fetching

    As noted in the post on task execution, RDD.iterator checks whether the RDD is cached and calls either getOrCompute or computeOrReadCheckpoint:

    final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
        if (storageLevel != StorageLevel.NONE) {
          SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
        } else {
          computeOrReadCheckpoint(split, context)
        }
      }

    computeOrReadCheckpoint calls compute to produce the result:

    private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
      {
        if (isCheckpointed) firstParent[T].iterator(split, context) else compute(split, context)
      }

    ShuffledRDD implements compute as follows; this is the entry point for reading the results computed by the ShuffleMapTasks:

    override def compute(split: Partition, context: TaskContext): Iterator[P] = {
        val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
        val ser = Serializer.getSerializer(serializer)
        SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
      }

    Next comes fetch, which retrieves the shuffle map output. It takes four parameters, shuffleId, reduceId, context and serializer, and returns an iterator over the elements of the fetched shuffle outputs:

    /**
       * Fetch the shuffle outputs for a given ShuffleDependency.
       * @return An iterator over the elements of the fetched shuffle outputs.
       */
      def fetch[T](
          shuffleId: Int,
          reduceId: Int,
          context: TaskContext,
          serializer: Serializer = SparkEnv.get.serializer): Iterator[T]

    The concrete implementation is BlockStoreShuffleFetcher.fetch:

    override def fetch[T](
          shuffleId: Int,
          reduceId: Int,
          context: TaskContext,
          serializer: Serializer)
        : Iterator[T] =
      {
    
        logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
        val blockManager = SparkEnv.get.blockManager
    
        val startTime = System.currentTimeMillis
        val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
        logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
          shuffleId, reduceId, System.currentTimeMillis - startTime))
    
        val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
        for (((address, size), index) <- statuses.zipWithIndex) {
          splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
        }
    
        val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
          case (address, splits) =>
            (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
        }
    
        def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
          val blockId = blockPair._1
          val blockOption = blockPair._2
          blockOption match {
            case Some(block) => {
              block.asInstanceOf[Iterator[T]]
            }
            case None => {
              blockId match {
                case ShuffleBlockId(shufId, mapId, _) =>
                  val address = statuses(mapId.toInt)._1
                  throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
                case _ =>
                  throw new SparkException(
                    "Failed to get block " + blockId + ", which is not a shuffle block")
              }
            }
          }
        }
    
        val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
        val itr = blockFetcherItr.flatMap(unpackBlock)
    
        val completionIter = CompletionIterator[T, Iterator[T]](itr, {
          val shuffleMetrics = new ShuffleReadMetrics
          shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
          shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
          shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
          shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
          shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
          shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
          context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
        })
    
        new InterruptibleIterator[T](context, completionIter)
      }

    Let's walk through this method.

    (1) mapOutputTracker.getServerStatuses is called so that the executor can obtain the server locations and map output sizes for this shuffle, i.e. the MapStatus information registered earlier (fetching it from the master tracker if it is not cached locally):

    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)

    getServerStatuses is implemented as follows. Note that it is called from the executors: given a shuffleId and a reduceId it returns, for each map output, the BlockManagerId that holds it and a Long giving that output's size, so a single BlockManagerId can be associated with the sizes of several files.

    /**
       * Called from executors to get the server URIs and output sizes of the map outputs of
       * a given shuffle.
       */
      def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
        val statuses = mapStatuses.get(shuffleId).orNull
        if (statuses == null) {
          logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
          var fetchedStatuses: Array[MapStatus] = null
          fetching.synchronized {
            if (fetching.contains(shuffleId)) {
              // Someone else is fetching it; wait for them to be done
              while (fetching.contains(shuffleId)) {
                try {
                  fetching.wait()
                } catch {
                  case e: InterruptedException =>
                }
              }
            }
    
            // Either while we waited the fetch happened successfully, or
            // someone fetched it in between the get and the fetching.synchronized.
            fetchedStatuses = mapStatuses.get(shuffleId).orNull
            if (fetchedStatuses == null) {
              // We have to do the fetch, get others to wait for us.
              fetching += shuffleId
            }
          }
    
          if (fetchedStatuses == null) {
            // We won the race to fetch the output locs; do so
            logInfo("Doing the fetch; tracker actor = " + trackerActor)
            // This try-finally prevents hangs due to timeouts:
            try {
              val fetchedBytes =
                askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
              fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
              logInfo("Got the output locations")
              mapStatuses.put(shuffleId, fetchedStatuses)
            } finally {
              fetching.synchronized {
                fetching -= shuffleId
                fetching.notifyAll()
              }
            }
          }
          if (fetchedStatuses != null) {
            fetchedStatuses.synchronized {
              return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
            }
          } else {
            throw new FetchFailedException(null, shuffleId, -1, reduceId,
              new Exception("Missing all output locations for shuffle " + shuffleId))
          }
        } else {
          statuses.synchronized {
            return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
          }
        }
      }

    Finally, convertMapStatuses converts the MapStatus array:

    // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
      // any of the statuses is null (indicating a missing location due to a failed mapper),
      // throw a FetchFailedException.
      private def convertMapStatuses(
            shuffleId: Int,
            reduceId: Int,
            statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
        assert (statuses != null)
        statuses.map {
          status =>
            if (status == null) {
              throw new FetchFailedException(null, shuffleId, -1, reduceId,
                new Exception("Missing an output location for shuffle " + shuffleId))
            } else {
              (status.location, decompressSize(status.compressedSizes(reduceId)))
            }
        }
      }

    The Long output size is the result of decompressSize, the inverse of compressSize:

    /**
       * Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
       */
      def decompressSize(compressedSize: Byte): Long = {
        if (compressedSize == 0) {
          0
        } else {
          math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong
        }
      }

    (2) Next, the mapping from BlockManagerId to block ids is built. A HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] is created and filled with getOrElseUpdate, the Int being the index of each status, which is exactly the mapId (0, 1, 2, ...). The map is then converted into (BlockManagerId, Seq[(ShuffleBlockId, size)]) pairs, where each ShuffleBlockId is assembled as ShuffleBlockId(shuffleId, mapId, reduceId):

      val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
      for (((address, size), index) <- statuses.zipWithIndex) {
        splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
      }
    
      val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
        case (address, splits) =>
          (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
      }
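
    A toy run of this grouping, with plain strings standing in for BlockManagerId and made-up sizes (everything below is illustrative):

    import scala.collection.mutable.{ArrayBuffer, HashMap}

    val shuffleId = 0
    val reduceId  = 3
    // one entry per mapId (the array index): (executor that holds the output, its size)
    val statuses = Array(("exec-1", 100L), ("exec-2", 50L), ("exec-1", 70L))

    val splitsByAddress = new HashMap[String, ArrayBuffer[(Int, Long)]]
    for (((address, size), index) <- statuses.zipWithIndex) {
      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
    }

    val blocksByAddress = splitsByAddress.toSeq.map { case (address, splits) =>
      (address, splits.map(s => (s"shuffle_${shuffleId}_${s._1}_$reduceId", s._2)))
    }
    // exec-1 -> Seq((shuffle_0_0_3, 100), (shuffle_0_2_3, 70))
    // exec-2 -> Seq((shuffle_0_1_3, 50))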

    (3) A helper function unpackBlock is defined: if a BlockId comes back with an iterator, that iterator is returned (cast to Iterator[T]); otherwise an exception is thrown:

    def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
          val blockId = blockPair._1
          val blockOption = blockPair._2
          blockOption match {
            case Some(block) => {
              block.asInstanceOf[Iterator[T]]
            }
            case None => {
              blockId match {
                case ShuffleBlockId(shufId, mapId, _) =>
                  val address = statuses(mapId.toInt)._1
                  throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
                case _ =>
                  throw new SparkException(
                    "Failed to get block " + blockId + ", which is not a shuffle block")
              }
            }
          }
        }

    Next, BlockManager.getMultiple is called to fetch multiple blocks from the local and remote block managers, and unpackBlock is applied through flatMap to validate each result and turn it into an iterator:

    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
    val itr = blockFetcherItr.flatMap(unpackBlock)

    Depending on whether Netty is enabled, getMultiple constructs either a BasicBlockFetcherIterator or a NettyBlockFetcherIterator:

    /**
      * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns
      * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined
      * fashion as they're received. Expects a size in bytes to be provided for each block fetched,
      * so that we can control the maxMegabytesInFlight for the fetch.
      */
      def getMultiple(
          blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
          serializer: Serializer): BlockFetcherIterator = {
        val iter =
          if (conf.getBoolean("spark.shuffle.use.netty", false)) {
            new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
          } else {
            new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
          }
    
        iter.initialize()
        iter
      }
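
    Switching between the two paths is purely a configuration choice. A minimal sketch with the Spark 1.0-era keys used in the code above and below; the values shown are just the defaults or an explicit opt-in, nothing more is implied.

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.shuffle.use.netty", "true")      // false by default, which selects BasicBlockFetcherIterator
      .set("spark.shuffle.copier.threads", "6")    // copier threads used by NettyBlockFetcherIterator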

    Let's look at BasicBlockFetcherIterator first.

    Its initialize method first splits the blocks into local and remote ones, adds the remote requests to the request queue in random order, sends out initial fetch requests without exceeding maxBytesInFlight, and then fetches the local blocks while the remote results are still coming back:

      override def initialize() {
          // Split local and remote blocks.
          val remoteRequests = splitLocalRemoteBlocks()
          // Add the remote requests into our queue in a random order
          fetchRequests ++= Utils.randomize(remoteRequests)
    
          // Send out initial requests for blocks, up to our maxBytesInFlight
          while (!fetchRequests.isEmpty &&
            (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
            sendRequest(fetchRequests.dequeue())
          }
    
          val numFetches = remoteRequests.size - fetchRequests.size
          logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
    
          // Get Local Blocks
          startTime = System.currentTimeMillis
          getLocalBlocks()
          logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
        }

    A few of the methods involved deserve a closer look.

    1. splitLocalRemoteBlocks

    Data is read from at most five nodes in parallel, and each request asks for no more than spark.reducer.maxMbInFlight / 5 bytes. Each BlockManagerId in blocksByAddress is compared with the local BlockManagerId to decide whether its blocks are local. For local blocks, zero-sized blocks are filtered out, the BlockIds are recorded in localBlocksToFetch, and the block count is accumulated. For remote blocks, zero-sized blocks are likewise skipped; the blocks are traversed with an iterator, each blockId is added to remoteBlocksToFetch and its size added to curRequestSize, and as soon as curRequestSize reaches targetRequestSize a FetchRequest is created and the counters reset. Whatever remains at the end of the traversal becomes one final request, and remoteRequests is returned.

    protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
      // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
      // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
      // nodes, rather than blocking on reading output from one node.
      val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
      logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
    
      // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
      // at most maxBytesInFlight in order to limit the amount of data in flight.
      val remoteRequests = new ArrayBuffer[FetchRequest]
      for ((address, blockInfos) <- blocksByAddress) {
        if (address == blockManagerId) {
          numLocal = blockInfos.size
          // Filter out zero-sized blocks
          localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
          _numBlocksToFetch += localBlocksToFetch.size
        } else {
          numRemote += blockInfos.size
          val iterator = blockInfos.iterator
          var curRequestSize = 0L
          var curBlocks = new ArrayBuffer[(BlockId, Long)]
          while (iterator.hasNext) {
            val (blockId, size) = iterator.next()
            // Skip empty blocks
            if (size > 0) {
              curBlocks += ((blockId, size))
              remoteBlocksToFetch += blockId
              _numBlocksToFetch += 1
              curRequestSize += size
            } else if (size < 0) {
              throw new BlockException(blockId, "Negative block size " + size)
            }
            if (curRequestSize >= targetRequestSize) {
              // Add this FetchRequest
              remoteRequests += new FetchRequest(address, curBlocks)
              curRequestSize = 0
              curBlocks = new ArrayBuffer[(BlockId, Long)]
              logDebug(s"Creating fetch request of $curRequestSize at $address")
            }
          }
          // Add in the final request
          if (!curBlocks.isEmpty) {
            remoteRequests += new FetchRequest(address, curBlocks)
          }
        }
      }
      logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
        totalBlocks + " blocks")
      remoteRequests
    }

    maxBytesInFlight is defined below, 48 MB by default; it caps the amount of data that may be in flight (already requested or about to be requested) per reducer:

    // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory
    // for receiving shuffle outputs)
    val maxBytesInFlight =
      conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024
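
    With the defaults above, the per-request budget works out as follows; this is just arithmetic on the defaults, not a measurement.

    val maxBytesInFlight  = 48L * 1024 * 1024                    // spark.reducer.maxMbInFlight = 48
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)   // about 9.6 MB per FetchRequest
    // so roughly 5 requests (ideally to 5 different executors) can be outstanding at once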

    2. sendRequest

    A connection is established through the ConnectionManager and the request is sent with sendMessageReliably. The returned messages are validated, and for each block in the reply, dataDeserialize turns the block's ByteBuffer into an iterator, which is stored as a FetchResult together with the blockId and its size:

    protected def sendRequest(req: FetchRequest) {
        logDebug("Sending request for %d blocks (%s) from %s".format(
          req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
        val cmId = new ConnectionManagerId(req.address.host, req.address.port)
        val blockMessageArray = new BlockMessageArray(req.blocks.map {
          case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
        })
        bytesInFlight += req.size
        val sizeMap = req.blocks.toMap  // so we can look up the size of each blockID
        val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
        future.onSuccess {
          case Some(message) => {
            val bufferMessage = message.asInstanceOf[BufferMessage]
            val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
            for (blockMessage <- blockMessageArray) {
              if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
                throw new SparkException(
                  "Unexpected message " + blockMessage.getType + " received from " + cmId)
              }
              val blockId = blockMessage.getId
              val networkSize = blockMessage.getData.limit()
              results.put(new FetchResult(blockId, sizeMap(blockId),
                () => dataDeserialize(blockId, blockMessage.getData, serializer)))
              _remoteBytesRead += networkSize
              logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
            }
          }
          case None => {
            logError("Could not get block(s) from " + cmId)
            for ((blockId, size) <- req.blocks) {
              results.put(new FetchResult(blockId, -1, null))
            }
          }
        }
      }

    3. getLocalBlocks

    As the comment explains, the local blocks can all be fetched at once, in parallel with the remote fetches, because reading them only memory-maps some files and therefore does not count against the maxBytesInFlight budget (the 48 MB cap).

    The method iterates over localBlocksToFetch; getLocalFromDisk ultimately calls diskStore.getValues to read the data straight from disk by blockId, returning an iterator:

    protected def getLocalBlocks() {
          // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
          // these all at once because they will just memory-map some files, so they won't consume
          // any memory that might exceed our maxBytesInFlight
          for (id <- localBlocksToFetch) {
            getLocalFromDisk(id, serializer) match {
              case Some(iter) => {
                // Pass 0 as size since it's not in flight
                results.put(new FetchResult(id, 0, () => iter))
                logDebug("Got local block " + id)
              }
              case None => {
                throw new BlockException(id, "Could not get block " + id + " from local machine")
              }
            }
          }
        }

    Now for NettyBlockFetcherIterator.

    Its initialize likewise calls splitLocalRemoteBlocks to separate local and remote blocks, enqueues the fetch requests in random order, starts the copiers that pull the remote blocks (6 copier threads in parallel by default), and then fetches the local blocks:

        override def initialize() {
          // Split Local Remote Blocks and set numBlocksToFetch
          val remoteRequests = splitLocalRemoteBlocks()
          // Add the remote requests into our queue in a random order
          for (request <- Utils.randomize(remoteRequests)) {
            fetchRequestsSync.put(request)
          }
    
          copiers = startCopiers(conf.getInt("spark.shuffle.copier.threads", 6))
          logInfo("Started " + fetchRequestsSync.size + " remote fetches in " +
            Utils.getUsedTimeMs(startTime))
    
          // Get Local Blocks
          startTime = System.currentTimeMillis
          getLocalBlocks()
          logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
        }

    copiers is simply a list of threads:

    private var copiers: List[_ <: Thread] = null

    startCopiers is implemented as follows; the key piece is the sendRequest that NettyBlockFetcherIterator overrides:

    private def startCopiers(numCopiers: Int): List[_ <: Thread] = {
          (for ( i <- Range(0,numCopiers) ) yield {
            val copier = new Thread {
              override def run(){
                try {
                  while(!isInterrupted && !fetchRequestsSync.isEmpty) {
                    sendRequest(fetchRequestsSync.take())
                  }
                } catch {
                  case x: InterruptedException => logInfo("Copier Interrupted")
                  // case _ => throw new SparkException("Exception Throw in Shuffle Copier")
                }
              }
            }
            copier.start
            copier
          }).toList
        }

    NettyBlockFetcherIterator.sendRequest creates a ShuffleCopier and calls ShuffleCopier.getBlocks to fetch the blocks:

    override protected def sendRequest(req: FetchRequest) {
    
          def putResult(blockId: BlockId, blockSize: Long, blockData: ByteBuf) {
            val fetchResult = new FetchResult(blockId, blockSize,
              () => dataDeserialize(blockId, blockData.nioBuffer, serializer))
            results.put(fetchResult)
          }
    
          logDebug("Sending request for %d blocks (%s) from %s".format(
            req.blocks.size, Utils.bytesToString(req.size), req.address.host))
          val cmId = new ConnectionManagerId(req.address.host, req.address.nettyPort)
          val cpier = new ShuffleCopier(blockManager.conf)
          cpier.getBlocks(cmId, req.blocks, putResult)
          logDebug("Sent request for remote blocks " + req.blocks + " from " + req.address.host )
        }

    getBlocks eventually calls getBlock, which creates a FileClient, sends the request and retrieves the blocks from the files, with Netty doing the actual transfer:

    def getBlock(host: String, port: Int, blockId: BlockId,
          resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
    
      val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
      val connectTimeout = conf.getInt("spark.shuffle.netty.connect.timeout", 60000)
      val fc = new FileClient(handler, connectTimeout)
    
      try {
        fc.init()
        fc.connect(host, port)
        fc.sendRequest(blockId.name)
        fc.waitForClose()
        fc.close()
      } catch {
        // Handle any socket-related exceptions in FileClient
        case e: Exception => {
          logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
          handler.handleError(blockId)
        }
      }
    }

    This completes the walkthrough of the shuffle write and fetch paths.

    END
