Each Partition of an RDD is processed by a different Task. Tasks come in two kinds: ShuffleMapTask and ResultTask.
1. Task Analysis
A Task is the basic unit of computation: one Task processes one Partition of an RDD. Tasks run on an Executor, and the Executor itself lives inside a CoarseGrainedExecutorBackend process.
In Spark, Tasks are classified into two kinds by where their Stage sits in the job:
A ShuffleMapTask is a Task whose Stage is not the final Stage.
A ResultTask is a Task whose Stage is the final Stage.
In other words, the Tasks of the last Stage are ResultTasks and all other Tasks are ShuffleMapTasks; the sketch below illustrates the mapping.
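For example, a job with a single shuffle has two Stages, and hence both kinds of Task. A minimal sketch (assuming an existing SparkContext sc and a hypothetical input file):

// A word count with one shuffle produces two Stages:
// Stage 0 (map side)  -> ShuffleMapTasks, one per partition of the input
// Stage 1 (final)     -> ResultTasks, one per reduce partition
val counts = sc.textFile("input.txt")   // hypothetical input path
  .flatMap(_.split(" "))
  .map(word => (word, 1))
  .reduceByKey(_ + _)                   // shuffle boundary -> new Stage
counts.collect()                        // the action triggers the job; the last Stage runs ResultTasks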
2. The Computation Process
Before an RDD is computed, the Driver sends the Executor a message telling it to launch the Task; once the Task is launched successfully, the Executor returns a success message to the Driver. The detailed steps are as follows:
1. The CoarseGrainedSchedulerBackend in the Driver sends a LaunchTask message to the CoarseGrainedExecutorBackend.
The receive method is as follows:
override def receive: PartialFunction[Any, Unit] = {
  case RegisteredExecutor =>
    logInfo("Successfully registered with driver")
    try {
      executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
    } catch {
      case NonFatal(e) =>
        exitExecutor(1, "Unable to create executor due to " + e.getMessage, e)
    }

  case RegisterExecutorFailed(message) =>
    exitExecutor(1, "Slave registration failed: " + message)

  case LaunchTask(data) =>
    if (executor == null) {
      exitExecutor(1, "Received LaunchTask command but executor was null")
    } else {
      // Deserialize the TaskDescription
      val taskDesc = TaskDescription.decode(data.value)
      logInfo("Got assigned task " + taskDesc.taskId)
      executor.launchTask(this, taskDesc)
    }

  case KillTask(taskId, _, interruptThread, reason) =>
    if (executor == null) {
      exitExecutor(1, "Received KillTask command but executor was null")
    } else {
      executor.killTask(taskId, interruptThread, reason)
    }

  case StopExecutor =>
    stopping.set(true)
    logInfo("Driver commanded a shutdown")
    // Cannot shutdown here because an ack may need to be sent back to the caller. So send
    // a message to self to actually do the shutdown.
    self.send(Shutdown)

  case Shutdown =>
    stopping.set(true)
    new Thread("CoarseGrainedExecutorBackend-stop-executor") {
      override def run(): Unit = {
        // executor.stop() will call `SparkEnv.stop()` which waits until RpcEnv stops totally.
        // However, if `executor.stop()` runs in some thread of RpcEnv, RpcEnv won't be able to
        // stop until `executor.stop()` returns, which becomes a dead-lock (See SPARK-14180).
        // Therefore, we put this line in a new thread.
        executor.stop()
      }
    }.start()
}
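For context, the sending side lives in CoarseGrainedSchedulerBackend.launchTasks on the Driver. A condensed excerpt (simplified here, omitting the RPC message size check; details vary across Spark versions):

// Driver side (condensed): serialize each TaskDescription and send it
// to the chosen executor's RPC endpoint as a LaunchTask message.
private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
  for (task <- tasks.flatten) {
    val serializedTask = TaskDescription.encode(task)
    val executorData = executorDataMap(task.executorId)
    executorData.freeCores -= scheduler.CPUS_PER_TASK
    executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
  }
}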
The decode method is as follows:
def decode(byteBuffer: ByteBuffer): TaskDescription = {
  val dataIn = new DataInputStream(new ByteBufferInputStream(byteBuffer))
  val taskId = dataIn.readLong()
  val attemptNumber = dataIn.readInt()
  val executorId = dataIn.readUTF()
  val name = dataIn.readUTF()
  val index = dataIn.readInt()

  // Read files.
  val taskFiles = deserializeStringLongMap(dataIn)

  // Read jars.
  val taskJars = deserializeStringLongMap(dataIn)

  // Read properties.
  val properties = new Properties()
  val numProperties = dataIn.readInt()
  for (i <- 0 until numProperties) {
    val key = dataIn.readUTF()
    val valueLength = dataIn.readInt()
    val valueBytes = new Array[Byte](valueLength)
    dataIn.readFully(valueBytes)
    properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8))
  }

  // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later).
  val serializedTask = byteBuffer.slice()

  new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars,
    properties, serializedTask)
}
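The deserializeStringLongMap helper is not shown above; it reads an entry count and then that many (name, timestamp) pairs. A minimal sketch consistent with the format decode expects (assuming scala.collection.mutable.HashMap):

// Sketch of the helper used for taskFiles/taskJars: an Int count,
// followed by `count` (readUTF key, readLong value) pairs.
private def deserializeStringLongMap(dataIn: DataInputStream): HashMap[String, Long] = {
  val map = new HashMap[String, Long]()
  val mapSize = dataIn.readInt()
  for (i <- 0 until mapSize) {
    map(dataIn.readUTF()) = dataIn.readLong()
  }
  map
}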
2. The Executor calls launchTask to run the Task.
3. launchTask creates a TaskRunner instance and executes the Task on the threadPool. Source:
def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
  // Create the TaskRunner
  val tr = new TaskRunner(context, taskDescription)
  // Register the TaskRunner in the runningTasks map, keyed by taskId
  runningTasks.put(taskDescription.taskId, tr)
  // Hand the TaskRunner to a thread from the pool
  threadPool.execute(tr)
}
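The runningTasks map and the threadPool used above are fields of Executor; in the Spark 2.x source they are declared roughly as follows (exact factory names may differ between versions):

// Approximate field declarations in org.apache.spark.executor.Executor:
private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]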
TaskRunner implements the Runnable interface; the source of its run method is as follows:
override def run(): Unit = {
  threadId = Thread.currentThread.getId
  Thread.currentThread.setName(threadName)
  val threadMXBean = ManagementFactory.getThreadMXBean
  val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  Thread.currentThread.setContextClassLoader(replClassLoader)
  val ser = env.closureSerializer.newInstance()
  logInfo(s"Running $taskName (TID $taskId)")
  // Tell the Driver that the task is now RUNNING
  execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
  var taskStart: Long = 0
  var taskStartCpu: Long = 0
  startGCTime = computeTotalGcTime()
  try {
    // Must be set before updateDependencies() is called, in case fetching dependencies
    // requires access to properties contained within (e.g. for access control).
    Executor.taskDeserializationProps.set(taskDescription.properties)
    updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
    // Deserialize the task
    task = ser.deserialize[Task[Any]](
      taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
    task.localProperties = taskDescription.properties
    task.setTaskMemoryManager(taskMemoryManager)

    // If this task has been killed before we deserialized it, let's quit now. Otherwise,
    // continue executing the task.
    val killReason = reasonIfKilled
    if (killReason.isDefined) {
      // Throw an exception rather than returning, because returning within a try{} block
      // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
      // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
      // for the task.
      throw new TaskKilledException(killReason.get)
    }

    logDebug("Task " + taskId + "'s epoch is " + task.epoch)
    env.mapOutputTracker.updateEpoch(task.epoch)

    // Run the actual task and measure its runtime.
    // Record the start time
    taskStart = System.currentTimeMillis()
    taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    var threwException = true
    // Invoke the task's run method
    val value = try {
      val res = task.run(
        taskAttemptId = taskId,
        attemptNumber = taskDescription.attemptNumber,
        metricsSystem = env.metricsSystem)
      threwException = false
      res
    } finally {
      val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
      val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()
      if (freedMemory > 0 && !threwException) {
        val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
        if (conf.getBoolean("spark.unsafe.exceptionOnMemoryLeak", false)) {
          throw new SparkException(errMsg)
        } else {
          logWarning(errMsg)
        }
      }
      if (releasedLocks.nonEmpty && !threwException) {
        val errMsg = s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
          releasedLocks.mkString("[", ", ", "]")
        if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false)) {
          throw new SparkException(errMsg)
        } else {
          logInfo(errMsg)
        }
      }
    }
    task.context.fetchFailed.foreach { fetchFailure =>
      // uh-oh. it appears the user code has caught the fetch-failure without throwing any
      // other exceptions. Its *possible* this is what the user meant to do (though highly
      // unlikely). So we will log an error and keep going.
      logError(s"TID ${taskId} completed successfully though internally it encountered " +
        s"unrecoverable fetch failures! Most likely this means user code is incorrectly " +
        s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure)
    }
    // Record the finish time
    val taskFinish = System.currentTimeMillis()
    val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    ...
The task.run method calls runTask; the source is as follows:
/**
 * Called by [[org.apache.spark.executor.Executor]] to run this task.
 *
 * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
 * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
 * @return the result of the task along with updates of Accumulators.
 */
final def run(
    taskAttemptId: Long,
    attemptNumber: Int,
    metricsSystem: MetricsSystem): T = {
  SparkEnv.get.blockManager.registerTask(taskAttemptId)
  context = new TaskContextImpl(
    stageId,
    partitionId,
    taskAttemptId,
    attemptNumber,
    taskMemoryManager,
    localProperties,
    metricsSystem,
    metrics)
  TaskContext.setTaskContext(context)
  taskThread = Thread.currentThread()
  if (_reasonIfKilled != null) {
    kill(interruptThread = false, _reasonIfKilled)
  }
  new CallerContext(
    "TASK",
    SparkEnv.get.conf.get(APP_CALLER_CONTEXT),
    appId,
    appAttemptId,
    jobId,
    Option(stageId),
    Option(stageAttemptId),
    Option(taskAttemptId),
    Option(attemptNumber)).setCurrentContext()
  try {
    runTask(context)
  } catch {
    case e: Throwable =>
      // Catch all errors; run task failure callbacks, and rethrow the exception.
      try {
        context.markTaskFailed(e)
      } catch {
        case t: Throwable =>
          e.addSuppressed(t)
      }
      context.markTaskCompleted(Some(e))
      throw e
  } finally {
    try {
      // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
      // one is no-op.
      context.markTaskCompleted(None)
    } finally {
      try {
        Utils.tryLogNonFatalError {
          // Release memory used by this thread for unrolling blocks
          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
            MemoryMode.OFF_HEAP)
          // Notify any tasks waiting for execution memory to be freed to wake up and try to
          // acquire memory again. This makes impossible the scenario where a task sleeps forever
          // because there are no other tasks left to notify it. Since this is safe to do but may
          // not be strictly necessary, we should revisit whether we can remove this in the
          // future.
          val memoryManager = SparkEnv.get.memoryManager
          memoryManager.synchronized { memoryManager.notifyAll() }
        }
      } finally {
        // Though we unset the ThreadLocal here, the context member variable itself is still
        // queried directly in the TaskRunner to check for FetchFailedExceptions.
        TaskContext.unset()
      }
    }
  }
}
The runTask method has two implementations, one in ShuffleMapTask and one in ResultTask; both are shown below.
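In the Task base class itself, runTask is declared abstract. A condensed sketch of the declaration (constructor parameters and most members trimmed for brevity; based on the Spark 2.x source):

// Condensed sketch of the Task base class:
private[spark] abstract class Task[T](
    val stageId: Int,
    val stageAttemptId: Int,
    val partitionId: Int) extends Serializable {

  // Implemented by ShuffleMapTask (T = MapStatus) and ResultTask (T = U, the action's result type)
  def runTask(context: TaskContext): T
}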
1. ShuffleMapTask
Its runTask method is as follows:
override def runTask(context: TaskContext): MapStatus = {
  // Deserialize the RDD using the broadcast variable.
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  // Create the closure serializer
  val ser = SparkEnv.get.closureSerializer.newInstance()
  // Deserialize to recover the RDD and its shuffle dependency
  val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L

  // Obtain a ShuffleWriter from the ShuffleManager and write the computed results through it
  var writer: ShuffleWriter[Any, Any] = null
  try {
    // Get the ShuffleManager from the SparkEnv
    val manager = SparkEnv.get.shuffleManager
    writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
    // Compute the partition and write the output
    writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
    writer.stop(success = true).get
  } catch {
    case e: Exception =>
      try {
        if (writer != null) {
          writer.stop(success = false)
        }
      } catch {
        case e: Exception =>
          log.debug("Could not stop writer", e)
      }
      throw e
  }
}
The method first deserializes the RDD and its dependency, and finally calls rdd.iterator to do the computation:
/**
 * Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
 * This should ''not'' be called by users directly, but is available for implementors of custom
 * subclasses of RDD.
 */
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
  if (storageLevel != StorageLevel.NONE) {
    getOrCompute(split, context)
  } else {
    computeOrReadCheckpoint(split, context)
  }
}
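Which branch is taken depends on whether the RDD has been persisted: a plain RDD has storageLevel == StorageLevel.NONE and goes straight to computeOrReadCheckpoint, while a cached RDD goes through getOrCompute. A minimal illustration (assuming an existing SparkContext sc and a hypothetical input path):

val rdd = sc.textFile("data.txt").map(_.length)  // storageLevel == NONE -> computeOrReadCheckpoint
rdd.cache()                                      // sets storageLevel to MEMORY_ONLY
rdd.count()                                      // first action: computes, then caches the blocks
rdd.count()                                      // second action: getOrCompute reads from the cache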
The computeOrReadCheckpoint method is as follows:
/**
 * Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
 */
private[spark] def computeOrReadCheckpoint(split: Partition, context: TaskContext): Iterator[T] =
{
  if (isCheckpointedAndMaterialized) {
    firstParent[T].iterator(split, context)
  } else {
    compute(split, context)
  }
}
The compute method is overridden by each concrete RDD subclass; MapPartitionsRDD, for example:
/**
 * An RDD that applies the provided function to every partition of the parent RDD.
 */
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
    var prev: RDD[T],
    f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
    preservesPartitioning: Boolean = false)
  extends RDD[U](prev) {

  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None

  override def getPartitions: Array[Partition] = firstParent[T].partitions

  override def compute(split: Partition, context: TaskContext): Iterator[U] =
    f(context, split.index, firstParent[T].iterator(split, context))

  override def clearDependencies() {
    super.clearDependencies()
    prev = null
  }
}
Here f is the function supplied when the MapPartitionsRDD was created. After the Partition has been computed, the ShuffleWriter obtained from the ShuffleManager writes the result to a shuffle file; once the write completes, the MapStatus is sent back to the MapOutputTracker of the DAGScheduler on the Driver.
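For instance, RDD.map builds exactly such an f: it wraps the user function so that it is applied lazily to each partition's iterator, roughly as in the Spark 2.x source:

/**
 * Return a new RDD by applying a function to all elements of this RDD.
 */
def map[U: ClassTag](f: T => U): RDD[U] = withScope {
  val cleanF = sc.clean(f)  // clean the closure so it can be serialized to executors
  new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
}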
2. ResultTask
The MapOutputTracker of the DAGScheduler on the Driver hands the output locations of the ShuffleMapTasks to the ResultTask, which shuffles in the results of the preceding Stage and produces the final result.
Source:
override def runTask(context: TaskContext): U = {
  // Deserialize the RDD and the func using the broadcast variables.
  val threadMXBean = ManagementFactory.getThreadMXBean
  val deserializeStartTime = System.currentTimeMillis()
  val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime
  } else 0L
  // Create the closure serializer
  val ser = SparkEnv.get.closureSerializer.newInstance()
  // Deserialize to recover the RDD and the func that computes the final result
  val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime
  _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
    threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
  } else 0L
  func(context, rdd.iterator(partition, context))
}
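The func here comes from the action that triggered the job. For collect(), for instance, it simply drains each partition's iterator into an array, roughly as in the Spark source:

/**
 * Return an array that contains all of the elements in this RDD.
 */
def collect(): Array[T] = withScope {
  // Each ResultTask's func is (iter: Iterator[T]) => iter.toArray;
  // the Driver then concatenates the per-partition arrays.
  val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
  Array.concat(results: _*)
}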