In Spark, two commonly used ways to bring computed results back to the driver are collect and take. How do they differ? Let's look at the code:
org.apache.spark.rdd.RDD
  /**
   * Return an array that contains all of the elements in this RDD.
   *
   * @note This method should only be used if the resulting array is expected to be small, as
   * all the data is loaded into the driver's memory.
   */
  def collect(): Array[T] = withScope {
    val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
    Array.concat(results: _*)
  }

  /**
   * Take the first num elements of the RDD. It works by first scanning one partition, and use the
   * results from that partition to estimate the number of additional partitions needed to satisfy
   * the limit.
   *
   * @note This method should only be used if the resulting array is expected to be small, as
   * all the data is loaded into the driver's memory.
   *
   * @note Due to complications in the internal implementation, this method will raise
   * an exception if called on an RDD of `Nothing` or `Null`.
   */
  def take(num: Int): Array[T] = withScope {
    val scaleUpFactor = Math.max(conf.getInt("spark.rdd.limit.scaleUpFactor", 4), 2)
    if (num == 0) {
      new Array[T](0)
    } else {
      val buf = new ArrayBuffer[T]
      val totalParts = this.partitions.length
      var partsScanned = 0
      while (buf.size < num && partsScanned < totalParts) {
        // The number of partitions to try in this iteration. It is ok for this number to be
        // greater than totalParts because we actually cap it at totalParts in runJob.
        var numPartsToTry = 1L
        if (partsScanned > 0) {
          // If we didn't find any rows after the previous iteration, quadruple and retry.
          // Otherwise, interpolate the number of partitions we need to try, but overestimate
          // it by 50%. We also cap the estimation in the end.
          if (buf.isEmpty) {
            numPartsToTry = partsScanned * scaleUpFactor
          } else {
            // the left side of max is >=1 whenever partsScanned >= 2
            numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
            numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
          }
        }

        val left = num - buf.size
        val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
        val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)

        res.foreach(buf ++= _.take(num - buf.size))
        partsScanned += p.size
      }

      buf.toArray
    }
  }
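As the code shows, collect runs a single job over every partition, turns each partition's result into an array, and then concatenates those arrays into one array on the driver.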
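The implementation of take is more involved. It first computes a single partition, then uses the number of results obtained to estimate how many more partitions it needs, computes those, and checks again whether it has enough. Each round can scan at most scaleUpFactor times (4 by default) the number of partitions already scanned. This is an iterative process: the simpler the computation or the smaller the number of elements requested, the more likely take is to satisfy the limit and return in an early iteration.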
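To make the iteration concrete, here is a minimal sketch that replays the same estimation loop on in-memory data, with no Spark involved. The names TakeSketch and its runJob helper are hypothetical stand-ins invented for this illustration (they are not Spark APIs). Partitions are modelled as plain sequences, and the method also returns how many partitions each pass scanned.

```scala
import scala.collection.mutable.{ArrayBuffer, ListBuffer}

object TakeSketch {
  // Hypothetical stand-in for sc.runJob: "computes" the partitions whose
  // indices are in p and returns at most `left` elements from each of them.
  def runJob[T](parts: IndexedSeq[Seq[T]], p: Range, left: Int): Seq[Seq[T]] =
    p.map(i => parts(i).take(left))

  // Mirrors the loop in RDD.take. Returns the taken elements together with
  // the number of partitions scanned in each pass.
  def take[T](parts: IndexedSeq[Seq[T]], num: Int,
              scaleUpFactor: Int = 4): (List[T], List[Int]) = {
    if (num == 0) {
      (List.empty[T], Nil)
    } else {
      val buf = ArrayBuffer.empty[T]
      val passes = ListBuffer.empty[Int]
      val totalParts = parts.length
      var partsScanned = 0
      while (buf.size < num && partsScanned < totalParts) {
        var numPartsToTry = 1L // the first pass always scans one partition
        if (partsScanned > 0) {
          if (buf.isEmpty) {
            // Nothing found so far: grow the scan window by scaleUpFactor.
            numPartsToTry = partsScanned.toLong * scaleUpFactor
          } else {
            // Interpolate from the hit rate so far, overestimating by 50%,
            // then cap the growth at scaleUpFactor.
            numPartsToTry =
              Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
            numPartsToTry = Math.min(numPartsToTry, partsScanned.toLong * scaleUpFactor)
          }
        }
        val left = num - buf.size
        val p = partsScanned.until(
          math.min(partsScanned + numPartsToTry, totalParts).toInt)
        runJob(parts, p, left).foreach(r => buf ++= r.take(num - buf.size))
        partsScanned += p.size
        passes += p.size
      }
      (buf.toList, passes.toList)
    }
  }
}
```

With eight partitions of ten elements each, TakeSketch.take(parts, 5) finishes after one pass over a single partition, while TakeSketch.take(parts, 25) needs two passes that scan 1 and then 2 partitions. If the first five partitions are empty and only the sixth holds data, the passes scan 1, 4, and 1 partitions, which shows the scale-up branch at work. In contrast, a collect over the same data would always touch all partitions in one job.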