• Spark Shuffle Module: Analysis of the Shuffle Read Process


    Before reading this article, please read "Spark Sort Based Shuffle Memory Analysis" first.

    The Spark Shuffle Read call stack is as follows:
    1. org.apache.spark.rdd.ShuffledRDD#compute()
    2. org.apache.spark.shuffle.ShuffleManager#getReader()
    3. org.apache.spark.shuffle.hash.HashShuffleReader#read()
    4. org.apache.spark.storage.ShuffleBlockFetcherIterator#initialize()
    5. org.apache.spark.storage.ShuffleBlockFetcherIterator#splitLocalRemoteBlocks()
       org.apache.spark.storage.ShuffleBlockFetcherIterator#sendRequest()
       org.apache.spark.storage.ShuffleBlockFetcherIterator#fetchLocalBlocks()

    The following classes and methods are involved when fetchLocalBlocks() runs:
    6. org.apache.spark.storage.BlockManager#getBlockData()
    org.apache.spark.shuffle.ShuffleManager#shuffleBlockResolver()
    The two ShuffleManager implementations discussed here are HashShuffleManager and SortShuffleManager. For Hash Shuffle, the method is org.apache.spark.shuffle.hash.HashShuffleManager#shuffleBlockResolver(), which returns an org.apache.spark.shuffle.FileShuffleBlockResolver. FileShuffleBlockResolver#getBlockData() is then called to return the block data.
    For Sort Shuffle, the method is org.apache.spark.shuffle.sort.SortShuffleManager#shuffleBlockResolver(), which returns an org.apache.spark.shuffle.IndexShuffleBlockResolver. IndexShuffleBlockResolver#getBlockData() is then called to return the block data.

    The following classes and methods are involved when org.apache.spark.storage.ShuffleBlockFetcherIterator#sendRequest() runs:
    7. org.apache.spark.network.shuffle.ShuffleClient#fetchBlocks

    org.apache.spark.network.shuffle.ShuffleClient has two subclasses: ExternalShuffleClient and BlockTransferService.
    org.apache.spark.network.BlockTransferService in turn has two subclasses, NettyBlockTransferService and NioBlockTransferService. These correspond to two different ways of fetching remote block data. NioBlockTransferService is already deprecated in Spark 1.5.2 and will be removed in a later release.

    The methods are explained below in call-stack order. Only the overall flow is covered here, and the details are discussed later.

    ShuffledRDD#compute()

    When a task runs, it calls ShuffledRDD#compute(). The code is as follows:

    //org.apache.spark.rdd.ShuffledRDD#compute()
    override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
        val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
        // Obtain the reader via org.apache.spark.shuffle.ShuffleManager#getReader().
        // The reader is org.apache.spark.shuffle.hash.HashShuffleReader
        // for both Sort Shuffle and Hash Shuffle.
        SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
          .read()
          .asInstanceOf[Iterator[(K, C)]]
      }

    The core logic calls ShuffleManager#getReader() to obtain a HashShuffleReader object. It then calls HashShuffleReader#read() to read the shuffle data that the ShuffleMapTasks of the previous stage produced. Note that HashShuffleReader is used for both Hash Shuffle and Sort Shuffle.

    HashShuffleReader#read()

    The source code of HashShuffleReader#read() is as follows:

    /** Read the combined key-values for this reduce task */
      override def read(): Iterator[Product2[K, C]] = {
        // Create a ShuffleBlockFetcherIterator. Its constructor calls initialize(),
        // which runs splitLocalRemoteBlocks() to decide how the data will be read.
        // Remote blocks are read via sendRequest().
        // Local blocks are read via fetchLocalBlocks().
        val blockFetcherItr = new ShuffleBlockFetcherIterator(
          context,
          blockManager.shuffleClient,
          blockManager,
          mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition),
          // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
          SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
    
        // Wrap the streams for compression based on configuration
        val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
          blockManager.wrapForCompression(blockId, inputStream)
        }
    
        val ser = Serializer.getSerializer(dep.serializer)
        val serializerInstance = ser.newInstance()
    
        // Create a key/value iterator for each stream
        val recordIter = wrappedStreams.flatMap { wrappedStream =>
          // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
          // NextIterator. The NextIterator makes sure that close() is called on the
          // underlying InputStream when all records have been read.
          serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
        }
    
        // Update the context task metrics for each record read.
        val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
        val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
          recordIter.map(record => {
            readMetrics.incRecordsRead(1)
            record
          }),
          context.taskMetrics().updateShuffleReadMetrics())
    
        // An interruptible iterator must be used here in order to support task cancellation
        val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
    
        val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
          if (dep.mapSideCombine) { 
            // Map-side combine was done: merge the partially combined values (combiners)
            val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
            dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
          } else {
            // No map-side combine: aggregate the raw values on the reduce side
            val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
            dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
          }
        } else {
          require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
          interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
        }
    
        // Sort the output if a key ordering is defined
        dep.keyOrdering match {
          case Some(keyOrd: Ordering[K]) =>
            // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
            // the ExternalSorter won't spill to disk.
            val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
            sorter.insertAll(aggregatedIter)
            context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
            context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
            context.internalMetricsToAccumulators(
              InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
            sorter.iterator
          case None =>
            aggregatedIter
        }
      }
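
    The sketch below makes the two aggregation branches and the final sort in read() concrete. It is a minimal in-memory version. SimpleAggregator and ReadPipelineSketch are hypothetical names. The three functions play the roles of Spark's createCombiner, mergeValue and mergeCombiners. The real code can spill to disk (for example through the ExternalSorter above), but nothing here spills.

```scala
import scala.collection.mutable

// Hypothetical in-memory stand-in for Spark's Aggregator (no spilling).
final case class SimpleAggregator[K, V, C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C) {

  // Branch taken when mapSideCombine is false: records carry raw values.
  def combineValuesByKey(iter: Iterator[(K, V)]): Iterator[(K, C)] = {
    val map = mutable.LinkedHashMap.empty[K, C]
    iter.foreach { case (k, v) =>
      val next = map.get(k) match {
        case Some(c) => mergeValue(c, v)
        case None    => createCombiner(v)
      }
      map(k) = next
    }
    map.iterator
  }

  // Branch taken when mapSideCombine is true: records carry partial combiners.
  def combineCombinersByKey(iter: Iterator[(K, C)]): Iterator[(K, C)] = {
    val map = mutable.LinkedHashMap.empty[K, C]
    iter.foreach { case (k, c) =>
      map(k) = map.get(k).fold(c)(prev => mergeCombiners(prev, c))
    }
    map.iterator
  }
}

object ReadPipelineSketch {
  // Final step of read(): sort only when a key ordering is defined.
  def sortIfOrdered[K, C](iter: Iterator[(K, C)], keyOrdering: Option[Ordering[K]]): Iterator[(K, C)] =
    keyOrdering match {
      case Some(ord) => iter.toVector.sortBy(_._1)(ord).iterator
      case None      => iter
    }
}
```

    With word-count style functions (identity, +, +), combineValuesByKey sums raw counts. combineCombinersByKey merges counts that the map tasks have already summed.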

    ShuffleBlockFetcherIterator#splitLocalRemoteBlocks()

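    splitLocalRemoteBlocks() decides how the data will be read. The localBlocks variable records the BlockIds held by the current executor, meaning address.executorId equals the local blockManagerId.executorId. The remoteBlocks variable records all BlockIds held by other executors.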

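    Remote blocks are grouped into FetchRequests. A request is closed as soon as its accumulated size reaches targetRequestSize = max(maxBytesInFlight / 5, 1). Each request is therefore roughly one fifth of maxBytesInFlight and may exceed that target by less than one block. maxBytesInFlight is set by spark.reducer.maxSizeInFlight. The requests are collected in: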

    val remoteRequests = new ArrayBuffer[FetchRequest]

    The source code of splitLocalRemoteBlocks() is as follows:

    private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
        // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
        // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
        // nodes, rather than blocking on reading output from one node.
        // maxBytesInFlight caps the total size of remote data fetched at the same time (default 48 MB).
        // It is set via SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024
        val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
        logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
    
        // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
        // at most maxBytesInFlight in order to limit the amount of data in flight.
        val remoteRequests = new ArrayBuffer[FetchRequest]
    
        // Tracks total number of blocks (including zero sized blocks)
        var totalBlocks = 0
        for ((address, blockInfos) <- blocksByAddress) {
          totalBlocks += blockInfos.size
          // The data to fetch lives on this executor (local)
          if (address.executorId == blockManager.blockManagerId.executorId) {
            // Filter out zero-sized blocks
            // Record the BlockIds of local blocks
            localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
            numBlocksToFetch += localBlocks.size
          } else {
           // The data is on a remote executor
            val iterator = blockInfos.iterator
            var curRequestSize = 0L
            var curBlocks = new ArrayBuffer[(BlockId, Long)]
            while (iterator.hasNext) {
              val (blockId, size) = iterator.next()
              // Skip empty blocks
              if (size > 0) {
                curBlocks += ((blockId, size))
                // Record the BlockIds of blocks on remote executors
                remoteBlocks += blockId
                numBlocksToFetch += 1
                curRequestSize += size
              } else if (size < 0) {
                throw new BlockException(blockId, "Negative block size " + size)
              }
              if (curRequestSize >= targetRequestSize) {
                // Add this FetchRequest
                remoteRequests += new FetchRequest(address, curBlocks)
                curBlocks = new ArrayBuffer[(BlockId, Long)]
                logDebug(s"Creating fetch request of $curRequestSize at $address")
                curRequestSize = 0
              }
            }
            // Add in the final request
            if (curBlocks.nonEmpty) {
              remoteRequests += new FetchRequest(address, curBlocks)
            }
          }
        }
        logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
        remoteRequests
      }
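
    The grouping rule above can be checked with a simplified, self-contained sketch. SplitSketch and FetchReq are hypothetical names, and an executor ID string stands in for BlockManagerId. Negative sizes raise an IllegalArgumentException instead of Spark's BlockException. With the default 48m setting, targetRequestSize is 48 * 1024 * 1024 / 5 = 10066329 bytes (integer division).

```scala
import scala.collection.mutable.ArrayBuffer

object SplitSketch {
  // Hypothetical simplification: executor IDs stand in for BlockManagerId,
  // and block IDs are plain strings.
  final case class FetchReq(executorId: String, blocks: Vector[(String, Long)]) {
    def size: Long = blocks.map(_._2).sum
  }

  def split(
      localExecutorId: String,
      blocksByAddress: Seq[(String, Seq[(String, Long)])],
      maxBytesInFlight: Long): (Vector[String], Vector[FetchReq]) = {
    // Same rule as the Spark code: aim for requests of about 1/5 of maxBytesInFlight.
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    val localBlocks = ArrayBuffer.empty[String]
    val remoteRequests = ArrayBuffer.empty[FetchReq]
    for ((executorId, blockInfos) <- blocksByAddress) {
      if (executorId == localExecutorId) {
        // Local executor: keep non-empty blocks only.
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
      } else {
        var cur = Vector.empty[(String, Long)]
        var curSize = 0L
        for ((blockId, size) <- blockInfos) {
          require(size >= 0, s"Negative block size $size for $blockId")
          if (size > 0) {
            cur = cur :+ (blockId -> size)
            curSize += size
          }
          // Close the current request once it reaches the target size.
          if (curSize >= targetRequestSize) {
            remoteRequests += FetchReq(executorId, cur)
            cur = Vector.empty
            curSize = 0L
          }
        }
        // The final, possibly smaller, request.
        if (cur.nonEmpty) remoteRequests += FetchReq(executorId, cur)
      }
    }
    (localBlocks.toVector, remoteRequests.toVector)
  }
}
```

    With maxBytesInFlight = 50 the target is 10. Remote blocks of sizes 4, 7, 0 and 3 therefore produce one request of 11 bytes (the first two blocks) and one of 3 bytes. The empty block is skipped.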

    ShuffleBlockFetcherIterator#fetchLocalBlocks()

    fetchLocalBlocks() reads the local blocks by calling BlockManager#getBlockData. Its source code is as follows:

    private[this] def fetchLocalBlocks() {
        val iter = localBlocks.iterator
        while (iter.hasNext) {
          val blockId = iter.next()
          try {
            // Call BlockManager#getBlockData
            val buf = blockManager.getBlockData(blockId)
            shuffleMetrics.incLocalBlocksFetched(1)
            shuffleMetrics.incLocalBytesRead(buf.size)
            buf.retain()
            results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
          } catch {
            case e: Exception =>
              // If we see an exception, stop immediately.
              logError(s"Error occurred while fetching local blocks", e)
              results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
              return
          }
        }
      }
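
    The key behavior of fetchLocalBlocks() has two parts. Results go into a queue that the iterator consumes, and the first failure ends the loop. Here is a minimal sketch of that pattern, using java.util.concurrent.LinkedBlockingQueue for the results queue. LocalFetchSketch and the readBlock callback are hypothetical; readBlock stands in for BlockManager#getBlockData. Reference counting (buf.retain()) is omitted.

```scala
import java.util.concurrent.LinkedBlockingQueue

object LocalFetchSketch {
  sealed trait FetchResult
  final case class Success(blockId: String, size: Long) extends FetchResult
  final case class Failure(blockId: String, error: Throwable) extends FetchResult

  // readBlock is a hypothetical stand-in for BlockManager#getBlockData.
  def fetchLocalBlocks(
      localBlocks: Seq[String],
      readBlock: String => Array[Byte],
      results: LinkedBlockingQueue[FetchResult]): Unit = {
    val iter = localBlocks.iterator
    while (iter.hasNext) {
      val blockId = iter.next()
      try {
        val buf = readBlock(blockId)
        results.put(Success(blockId, buf.length.toLong))
      } catch {
        case e: Exception =>
          // As in the Spark code: report the failure and stop immediately.
          results.put(Failure(blockId, e))
          return
      }
    }
  }
}
```

    If the second of three blocks fails, the queue ends up holding one success and one failure. The third block is never read.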

    The source code of BlockManager#getBlockData is as follows:

    override def getBlockData(blockId: BlockId): ManagedBuffer = {
        if (blockId.isShuffle) {
          // First call ShuffleManager#shuffleBlockResolver to obtain the ShuffleBlockResolver,
          // then call its getBlockData method
          shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
        } else {
          val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
            .asInstanceOf[Option[ByteBuffer]]
          if (blockBytesOpt.isDefined) {
            val buffer = blockBytesOpt.get
            new NioManagedBuffer(buffer)
          } else {
            throw new BlockNotFoundException(blockId.toString)
          }
        }
      }
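
    getBlockData() branches on blockId.isShuffle. The sketch below mirrors that branch with a simplified block ID model. Spark's ShuffleBlockId is named shuffle_<shuffleId>_<mapId>_<reduceId>. sendRequest() converts that same string form back with BlockId(blockId). BlockIdSketch, ShuffleBlock, OtherBlock and route are hypothetical names, and all non-shuffle block kinds are collapsed into OtherBlock.

```scala
object BlockIdSketch {
  // Simplified model: only shuffle block IDs are recognized; everything else is OtherBlock.
  sealed trait SimpleBlockId { def isShuffle: Boolean = false }
  final case class ShuffleBlock(shuffleId: Int, mapId: Int, reduceId: Int) extends SimpleBlockId {
    override def isShuffle: Boolean = true
  }
  final case class OtherBlock(name: String) extends SimpleBlockId

  // Naming pattern "shuffle_<shuffleId>_<mapId>_<reduceId>".
  private val Shuffle = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r

  def parse(name: String): SimpleBlockId = name match {
    case Shuffle(s, m, r) => ShuffleBlock(s.toInt, m.toInt, r.toInt)
    case other            => OtherBlock(other)
  }

  // Mirrors the branch in getBlockData: shuffle blocks go to the shuffle resolver,
  // everything else goes to the local block store. Both callbacks are hypothetical.
  def route[A](id: SimpleBlockId, fromResolver: ShuffleBlock => A, fromLocalStore: SimpleBlockId => A): A =
    id match {
      case s: ShuffleBlock => fromResolver(s)
      case other           => fromLocalStore(other)
    }
}
```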

    org.apache.spark.shuffle.ShuffleManager#shuffleBlockResolver() returns the matching ShuffleBlockResolver. For Hash Shuffle, that is org.apache.spark.shuffle.FileShuffleBlockResolver. For Sort Shuffle, it is org.apache.spark.shuffle.IndexShuffleBlockResolver.

    That resolver's getBlockData method is then called. It returns the requested file segment, wrapped in a FileSegmentManagedBuffer.


    The source code of FileShuffleBlockResolver#getBlockData is as follows:

    override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
        // Files produced by Hash Shuffle's consolidated shuffle files mechanism (spark.shuffle.consolidateFiles)
        if (consolidateShuffleFiles) { 
          // Search all file groups associated with this shuffle.
          val shuffleState = shuffleStates(blockId.shuffleId)
          val iter = shuffleState.allFileGroups.iterator
          while (iter.hasNext) {
            val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId)
            if (segmentOpt.isDefined) {
              val segment = segmentOpt.get
              return new FileSegmentManagedBuffer(
                transportConf, segment.file, segment.offset, segment.length)
            }
          }
          throw new IllegalStateException("Failed to find shuffle block: " + blockId)
        } else {
          // Files produced by the plain Hash Shuffle mechanism (one file per map task and reducer)
          val file = blockManager.diskBlockManager.getFile(blockId)
          new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
        }
      }
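
    The consolidated branch searches every file group of the shuffle until one holds a segment for (mapId, reduceId). The following sketch models that search under stated assumptions. FileGroup, Segment and find are hypothetical names, and each group has one file per reducer. Here a segment's offset is computed as the sum of the lengths written before it, whereas Spark's file group keeps its own offset bookkeeping.

```scala
object ConsolidatedLookupSketch {
  final case class Segment(file: String, offset: Long, length: Long)

  // Hypothetical model of one consolidated file group: one file per reducer, and
  // each map task that used the group appended one segment to every reducer file.
  final case class FileGroup(
      files: Vector[String],                 // files(reduceId)
      mapIds: Vector[Int],                   // map tasks in the order they wrote
      lengthsByReducer: Vector[Vector[Long]] // lengthsByReducer(reduceId)(writeIndex)
  ) {
    def segmentFor(mapId: Int, reduceId: Int): Option[Segment] = {
      val index = mapIds.indexOf(mapId)
      if (index < 0) None
      else {
        val lengths = lengthsByReducer(reduceId)
        // Segments sit back to back, so the offset is the sum of the earlier lengths.
        Some(Segment(files(reduceId), lengths.take(index).sum, lengths(index)))
      }
    }
  }

  // Mirrors the loop in getBlockData: try each group of the shuffle in turn.
  def find(groups: Seq[FileGroup], mapId: Int, reduceId: Int): Segment =
    groups.iterator
      .map(_.segmentFor(mapId, reduceId))
      .collectFirst { case Some(s) => s }
      .getOrElse(throw new IllegalStateException(
        s"Failed to find shuffle block: map $mapId, reduce $reduceId"))
}
```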

    The source code of IndexShuffleBlockResolver#getBlockData is as follows:

    override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
        // The block is actually going to be a range of a single map output file for this map, so
        // find out the consolidated file, then the offset within that from our index
        // Use shuffleId and mapId to locate the corresponding index file
        val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
    
        val in = new DataInputStream(new FileInputStream(indexFile))
        try {
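          // Skip to this block's entry: the index stores one 8-byte offset per reduce partition
          ByteStreams.skipFully(in, blockId.reduceId * 8)
          // Start offset of this block's data in the data file
          val offset = in.readLong()
          // End offset (where the next partition's data starts)
          val nextOffset = in.readLong()
          // Return the file segment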
          new FileSegmentManagedBuffer(
            transportConf,
            getDataFile(blockId.shuffleId, blockId.mapId),
            offset,
            nextOffset - offset)
        } finally {
          in.close()
        }
      }
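
    The index file lookup can be reproduced with plain java.io streams. The sketch assumes the layout that getBlockData reads: numPartitions + 1 longs, starting at 0 and accumulating partition lengths. Partition i therefore spans [index(i), index(i + 1)) in the data file. IndexFileSketch is a hypothetical name, and the index lives in a byte array instead of a file.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object IndexFileSketch {
  // Builds index bytes in the assumed layout: a leading 0, then the running
  // total of partition lengths (numPartitions + 1 longs in all).
  def writeIndex(partitionLengths: Seq[Long]): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    var offset = 0L
    out.writeLong(offset)
    partitionLengths.foreach { len => offset += len; out.writeLong(offset) }
    out.close()
    bytes.toByteArray
  }

  // Same lookup as getBlockData: skip reduceId longs, then read two adjacent offsets.
  def segmentFor(index: Array[Byte], reduceId: Int): (Long, Long) = {
    val in = new DataInputStream(new ByteArrayInputStream(index))
    try {
      val toSkip = reduceId * 8
      require(in.skipBytes(toSkip) == toSkip, s"index too short for reduceId $reduceId")
      val offset = in.readLong()
      val nextOffset = in.readLong()
      (offset, nextOffset - offset) // (offset, length) within the data file
    } finally in.close()
  }
}
```

    For partition lengths 100, 0 and 250, the index is 0, 100, 100, 350. Partition 1 is an empty segment at offset 100.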

    ShuffleBlockFetcherIterator#sendRequest()

    sendRequest() fetches data from remote machines:

     private[this] def sendRequest(req: FetchRequest) {
        logDebug("Sending request for %d blocks (%s) from %s".format(
          req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
        bytesInFlight += req.size
    
        // so we can look up the size of each blockID
        val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
        val blockIds = req.blocks.map(_._1.toString)
    
        val address = req.address
        // Fetch the data with ShuffleClient#fetchBlocks.
        // There are two kinds of ShuffleClient: ExternalShuffleClient and BlockTransferService.
        // BlockTransferService is used by default.
        shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
          new BlockFetchingListener {
            override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
              // Only add the buffer to results queue if the iterator is not zombie,
              // i.e. cleanup() has not been called yet.
              if (!isZombie) {
                // Increment the ref count because we need to pass this to a different thread.
                // This needs to be released after use.
                buf.retain()
                results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
                shuffleMetrics.incRemoteBytesRead(buf.size)
                shuffleMetrics.incRemoteBlocksFetched(1)
              }
              logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
            }
    
            override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
              logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
              results.put(new FailureFetchResult(BlockId(blockId), address, e))
            }
          }
        )
      }
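
    sendRequest() adds req.size to bytesInFlight before fetching. The simulation below shows how such a counter can keep the amount of data in flight near maxBytesInFlight. The sending condition is an assumption and is not part of the code shown here. Under it, the next request is sent only if nothing is in flight or if it still fits under maxBytesInFlight. Responses are also assumed to arrive in FIFO order. InFlightSketch is a hypothetical name.

```scala
import scala.collection.mutable

object InFlightSketch {
  // Simulates flow control around sendRequest(): bytesInFlight grows when a request
  // is sent and shrinks when its data arrives. ASSUMED rule: send the next request
  // only if nothing is in flight or it still fits under maxBytesInFlight.
  // Returns the peak number of bytes in flight at any moment.
  def simulate(requestSizes: Seq[Long], maxBytesInFlight: Long): Long = {
    val pending = mutable.Queue.empty[Long]
    pending ++= requestSizes
    val inFlight = mutable.Queue.empty[Long]
    var bytesInFlight = 0L
    var peak = 0L

    def sendWhileAllowed(): Unit =
      while (pending.nonEmpty &&
        (bytesInFlight == 0 || bytesInFlight + pending.head <= maxBytesInFlight)) {
        val size = pending.dequeue()
        bytesInFlight += size // same bookkeeping as sendRequest()
        inFlight.enqueue(size)
        peak = math.max(peak, bytesInFlight)
      }

    sendWhileAllowed()
    while (inFlight.nonEmpty) {
      bytesInFlight -= inFlight.dequeue() // a response arrived (FIFO, for simplicity)
      sendWhileAllowed()
    }
    peak
  }
}
```

    The "nothing in flight" clause lets a single request larger than maxBytesInFlight still make progress. Without it, such a request would never be sent.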
    

    The code above fetches remote block data with shuffleClient.fetchBlocks. org.apache.spark.network.shuffle.ShuffleClient has two subclasses, ExternalShuffleClient and BlockTransferService. org.apache.spark.network.BlockTransferService in turn has two subclasses, NettyBlockTransferService and NioBlockTransferService. The shuffleClient object is defined in org.apache.spark.storage.BlockManager as follows:

    // shuffleClient as defined in org.apache.spark.storage.BlockManager
     private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
        // Use ExternalShuffleClient to fetch remote block data
        val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
        new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(),
          securityManager.isSaslEncryptionEnabled())
      } else {
        // Use NettyBlockTransferService or NioBlockTransferService to fetch remote block data
        blockTransferService
      }

    The blockTransferService used here is initialized in SparkEnv as follows:

     // blockTransferService is initialized in org.apache.spark.SparkEnv
     val blockTransferService =
          conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
            case "netty" =>
              new NettyBlockTransferService(conf, securityManager, numUsableCores)
            case "nio" =>
              logWarning("NIO-based block transfer service is deprecated, " +
                "and will be removed in Spark 1.6.0.")
              new NioBlockTransferService(conf, securityManager)
          }
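
    Taken together, the two snippets show that two settings drive the choice of client. The sketch returns a label instead of constructing real objects. ShuffleClientChoiceSketch is a hypothetical name. spark.shuffle.service.enabled (default false) is assumed to be the setting behind externalShuffleServiceEnabled. Here an unknown transfer service name raises IllegalArgumentException, whereas the match in SparkEnv above would fail with a MatchError.

```scala
object ShuffleClientChoiceSketch {
  // Returns a label for the client BlockManager would use, given a conf map.
  // ASSUMPTION: "spark.shuffle.service.enabled" backs externalShuffleServiceEnabled.
  def choose(conf: Map[String, String]): String =
    if (conf.getOrElse("spark.shuffle.service.enabled", "false").toBoolean) "ExternalShuffleClient"
    else conf.getOrElse("spark.shuffle.blockTransferService", "netty").toLowerCase match {
      case "netty" => "NettyBlockTransferService"
      case "nio"   => "NioBlockTransferService (deprecated)"
      case other   => throw new IllegalArgumentException(s"Unknown block transfer service: $other")
    }
}
```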
    
  • Original article: https://www.cnblogs.com/zsychanpin/p/7226064.html