For the Shuffle Write side, see the Shuffle Write analysis.
This article covers the reduce side of shuffle. The first RDD of the stage downstream of a shuffle is a ShuffledRDD, and its compute method returns an iterator over the data that the upstream stage's Shuffle Write spilled to disk files:
override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}
It obtains the shuffleManager (here a SortShuffleManager) from SparkEnv, asks the manager for a reader, and calls the reader's read method to get the iterator.
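For orientation, here is a tiny driver program whose second stage starts from a ShuffledRDD and therefore goes through exactly this read path when its tasks run. This is only an illustration: the object name and the master/app settings are made up, and it assumes spark-core is on the classpath.

import org.apache.spark.{SparkConf, SparkContext}

object ShuffleReadDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("shuffle-read-demo"))
    val words = sc.parallelize(Seq("a", "b", "a", "c", "b", "a"), numSlices = 2)

    // reduceByKey introduces a ShuffleDependency; the tasks of the downstream stage start
    // from a ShuffledRDD, whose compute() calls getReader(...).read() as shown above
    val counts = words.map(w => (w, 1)).reduceByKey(_ + _)
    counts.collect().foreach(println)

    sc.stop()
  }
}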
getReader itself is:
override def getReader[K, C](
    handle: ShuffleHandle,
    startPartition: Int,
    endPartition: Int,
    context: TaskContext): ShuffleReader[K, C] = {
  new BlockStoreShuffleReader(
    handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
}
getReader instantiates a BlockStoreShuffleReader; its arguments include the range of partition ids to fetch (startPartition until endPartition). Let's look at its read method:
override def read(): Iterator[Product2[K, C]] = {
  val blockFetcherItr = new ShuffleBlockFetcherIterator(
    context,
    blockManager.shuffleClient,
    blockManager,
    // Metadata describing where the data to read is stored
    mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
    // Maximum size transferred per remote request
    SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
    SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

  // Wrap the streams for compression and encryption
  val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
    serializerManager.wrapStream(blockId, inputStream)
  }

  val serializerInstance = dep.serializer.newInstance()

  // Create a key/value iterator for each stream
  val recordIter = wrappedStreams.flatMap { wrappedStream =>
    serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
  }

  // Update the task metrics for every record read
  val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
  // Build the complete iterator, with a completion callback
  val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
    recordIter.map { record =>
      readMetrics.incRecordsRead(1)
      record
    },
    context.taskMetrics().mergeShuffleReadMetrics())

  // An interruptible iterator must be used here in order to support task cancellation
  val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

  val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
    if (dep.mapSideCombine) {
      // Values were already combined once on the map side
      val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
      dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
    } else {
      // Combine only on the reduce side
      val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
      dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
    }
  } else {
    require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
    interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
  }

  // Sort globally if a key ordering is required
  dep.keyOrdering match {
    case Some(keyOrd: Ordering[K]) =>
      val sorter = new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
      sorter.insertAll(aggregatedIter)
      context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
      context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
      context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
      CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
    case None =>
      aggregatedIter
  }
}
The first thing read() does is construct a ShuffleBlockFetcherIterator. One of its arguments is:
mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition)
This call fetches the metadata describing where the reduce side's input data lives. It returns a Seq[(BlockManagerId, Seq[(BlockId, Long)])]: which blocks on which nodes hold the data, and how large each block is. Let's see how getMapSizesByExecutorId is implemented:
def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
    : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
  logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
  // Fetch the metadata for this shuffle
  val statuses = getStatuses(shuffleId)
  // Convert the format and pick out the metadata for the requested partitions
  statuses.synchronized {
    return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
  }
}
It first fetches all metadata for the given shuffleId, then converts it and extracts the metadata for the requested partitions. Let's follow getStatuses:
private def getStatuses(shuffleId: Int): Array[MapStatus] = {
  // First try to read directly from the local mapStatuses
  val statuses = mapStatuses.get(shuffleId).orNull
  if (statuses == null) {
    logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
    val startTime = System.currentTimeMillis
    var fetchedStatuses: Array[MapStatus] = null
    ......
    if (fetchedStatuses == null) {
      // We won the race to fetch the statuses; do so
      logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
      // This try-finally prevents hangs due to timeouts:
      try {
        // Ask the tracker for the metadata
        val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
        // Deserialize it
        fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
        logInfo("Got the output locations")
        // Cache it in mapStatuses
        mapStatuses.put(shuffleId, fetchedStatuses)
      } finally {
        fetching.synchronized {
          fetching -= shuffleId
          fetching.notifyAll()
        }
      }
    }
    .....
  } else {
    return statuses
  }
}
If the statuses are already in the local mapStatuses they are returned directly; otherwise the executor sends a GetMapOutputStatuses message to the MapOutputTrackerMaster to fetch the metadata.
Recall that each Executor corresponds to one CoarseGrainedExecutorBackend; constructing the CoarseGrainedExecutorBackend creates a SparkEnv, and creating the SparkEnv creates a mapOutputTracker. In other words there is one mapOutputTracker per Executor, maintaining shuffle metadata on that Executor.
The mapStatuses seen above is where that mapOutputTracker keeps its metadata. On an executor it acts as a cache: metadata fetched remotely about Shuffle Writes completed on other Executors is stored there for reuse.
The executor side uses a MapOutputTrackerWorker while the driver uses a MapOutputTrackerMaster; both are created when the SparkEnv is instantiated. The result of every shuffle map task completed on an executor is registered with the driver's MapOutputTrackerMaster, so the master's mapStatuses holds all of the metadata. When a task on an executor needs the output of a shuffle, it therefore looks in its own mapStatuses first and, on a miss, contacts the MapOutputTrackerMaster to fetch the metadata.
When the MapOutputTrackerMaster receives the message, its endpoint handles it like this:
case GetMapOutputStatuses(shuffleId: Int) =>
  val hostPort = context.senderAddress.hostPort
  logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
  val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))
It calls the tracker's post method:
def post(message: GetMapOutputMessage): Unit = {
  mapOutputRequests.offer(message)
}
The message is added to mapOutputRequests, a linked blocking queue. When the MapOutputTrackerMaster is initialized it starts a dedicated thread pool to serve these requests:
private val threadpool: ThreadPoolExecutor = {
  val numThreads = conf.getInt("spark.shuffle.mapOutput.dispatcher.numThreads", 8)
  val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "map-output-dispatcher")
  for (i <- 0 until numThreads) {
    pool.execute(new MessageLoop)
  }
  pool
}
Let's see how the run method of the worker class MessageLoop is defined:
private class MessageLoop extends Runnable {
  override def run(): Unit = {
    try {
      while (true) {
        try {
          // Take one GetMapOutputMessage from the queue
          val data = mapOutputRequests.take()
          if (data == PoisonPill) {
            // Put PoisonPill back so that other MessageLoops can see it.
            mapOutputRequests.offer(PoisonPill)
            return
          }
          val context = data.context
          val shuffleId = data.shuffleId
          val hostPort = context.senderAddress.hostPort
          logDebug("Handling request to send map output locations for shuffle " + shuffleId + " to " + hostPort)
          // Look up the serialized metadata for this shuffleId
          val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
          // Reply with the serialized bytes
          context.reply(mapOutputStatuses)
        } catch {
          case NonFatal(e) => logError(e.getMessage, e)
        }
      }
    } catch {
      case ie: InterruptedException => // exit
    }
  }
}
It looks up the serialized metadata for the shuffleId and replies with it. Let's look at the implementation of getSerializedMapOutputStatuses:
def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
  var statuses: Array[MapStatus] = null
  var retBytes: Array[Byte] = null
  var epochGotten: Long = -1

  // Try the cache of serialized statuses first; fall back to mapStatuses
  def checkCachedStatuses(): Boolean = {
    epochLock.synchronized {
      if (epoch > cacheEpoch) {
        cachedSerializedStatuses.clear()
        clearCachedBroadcast()
        cacheEpoch = epoch
      }
      cachedSerializedStatuses.get(shuffleId) match {
        case Some(bytes) =>
          retBytes = bytes
          true
        case None =>
          logDebug("cached status not found for : " + shuffleId)
          statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus])
          epochGotten = epoch
          false
      }
    }
  }

  if (checkCachedStatuses()) return retBytes

  var shuffleIdLock = shuffleIdLocks.get(shuffleId)
  if (null == shuffleIdLock) {
    val newLock = new Object()
    // in general, this condition should be false - but good to be paranoid
    val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
    shuffleIdLock = if (null != prevLock) prevLock else newLock
  }
  // synchronize so we only serialize/broadcast it once since multiple threads call
  // in parallel
  shuffleIdLock.synchronized {
    if (checkCachedStatuses()) return retBytes
    // Serialize the statuses
    val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager,
      isLocal, minSizeForBroadcast)
    logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
    // Add them into the table only if the epoch hasn't changed while we were working
    epochLock.synchronized {
      if (epoch == epochGotten) {
        cachedSerializedStatuses(shuffleId) = bytes
        if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
      } else {
        logInfo("Epoch changed, not caching!")
        removeBroadcast(bcast)
      }
    }
    bytes
  }
}
The overall approach: first look in the cache of serialized statuses and return immediately on a hit; otherwise take the statuses from mapStatuses, serialize them, cache the bytes, and return them. The bytes are sent back to the MapOutputTrackerWorker that asked; on receiving the reply, the worker deserializes the metadata and adds it to the mapStatuses of its own Executor. The per-shuffle, double-checked locking used here is a pattern worth seeing on its own; a sketch follows.
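A minimal, self-contained sketch of that caching pattern (the class and names are hypothetical; it omits the epoch invalidation and broadcast handling that the real method also deals with):

import scala.collection.concurrent.TrieMap

class SerializedStatusCache[K](compute: K => Array[Byte]) {
  private val cache = TrieMap.empty[K, Array[Byte]]
  private val locks = TrieMap.empty[K, Object]

  def get(key: K): Array[Byte] = {
    // Fast path: return a previously serialized result if present
    cache.get(key) match {
      case Some(bytes) => bytes
      case None =>
        // Per-key lock so only one thread pays the serialization cost
        val lock = locks.getOrElseUpdate(key, new Object)
        lock.synchronized {
          // Second check: another thread may have filled the cache while we waited
          cache.getOrElseUpdate(key, compute(key))
        }
    }
  }
}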
Back in getMapSizesByExecutorId: once getStatuses has returned all metadata for the shuffleId, convertMapStatuses turns it into location information of the form Seq[(BlockManagerId, Seq[(BlockId, Long)])], which is what is used to read the requested partitions:
private def convertMapStatuses(
    shuffleId: Int,
    startPartition: Int,
    endPartition: Int,
    statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
  assert (statuses != null)
  // Collect, per location, the blocks holding the requested partitions
  val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
  for ((status, mapId) <- statuses.zipWithIndex) {
    if (status == null) {
      val errorMessage = s"Missing an output location for shuffle $shuffleId"
      logError(errorMessage)
      throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
    } else {
      for (part <- startPartition until endPartition) {
        splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
          ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
      }
    }
  }
  splitsByAddress.toSeq
}
Here statuses: Array[MapStatus] is the metadata of all the upstream stage's Shuffle Write files, ordered by the map-side partition id. zipWithIndex pairs each element with its index in the array, and that index is exactly the map-side partition id (mapId). A ShuffleBlockId is then built from the shuffleId, the mapId, and the reduce-side partition id. (When the map side builds a ShuffleBlockId the reduce partition id is always 0, because one ShuffleMapTask writes a single block; the real reduce partition id filled in here is what is later used against the index file to find that partition's offsets.) The estimated size of each block is recorded as well, because remote fetches are limited by size. Finally the per-location information is returned; a small self-contained sketch of this grouping follows.
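To make the shape of the result concrete, here is a self-contained sketch of the same grouping. The case classes are simplified stand-ins for Spark's BlockManagerId, MapStatus, and ShuffleBlockId, and the sizes are made up:

import scala.collection.mutable.{ArrayBuffer, HashMap}

// Simplified stand-ins for Spark's metadata classes (for illustration only)
case class BlockManagerId(executorId: String, host: String, port: Int)
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) {
  def name: String = s"shuffle_${shuffleId}_${mapId}_${reduceId}"
}
case class MapStatus(location: BlockManagerId, sizes: Array[Long]) {
  def getSizeForBlock(reduceId: Int): Long = sizes(reduceId)
}

object ConvertExample extends App {
  val shuffleId = 0
  // Two map outputs; the index in this array is the map-side partition id
  val statuses = Array(
    MapStatus(BlockManagerId("exec-1", "host-a", 7337), Array(10L, 20L, 30L)),
    MapStatus(BlockManagerId("exec-2", "host-b", 7337), Array(5L, 15L, 25L)))

  val startPartition = 1
  val endPartition = 2 // fetch reduce partition 1 only

  val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(ShuffleBlockId, Long)]]
  for ((status, mapId) <- statuses.zipWithIndex; part <- startPartition until endPartition) {
    splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
      ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
  }

  // e.g. BlockManagerId(exec-1,host-a,7337) -> ArrayBuffer((ShuffleBlockId(0,0,1),20))
  splitsByAddress.foreach { case (loc, blocks) => println(s"$loc -> $blocks") }
}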
At this point mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition) is done: it has returned the MapStatus-based metadata for the requested partitions.
When the ShuffleBlockFetcherIterator object is constructed, its initialize() method is called:
private[this] def initialize(): Unit = {
  // Add a task completion callback (called in both success case and failure case) to cleanup.
  context.addTaskCompletionListener(_ => cleanup())

  // Split the blocks into local and remote blocks and build the remote FetchRequests
  val remoteRequests = splitLocalRemoteBlocks()
  // Add the remote requests to the fetchRequests queue in random order
  fetchRequests ++= Utils.randomize(remoteRequests)
  assert ((0 == reqsInFlight) == (0 == bytesInFlight),
    "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
    ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)

  // Dequeue remote requests from fetchRequests and send them via sendRequest
  fetchUpToMaxBytes()

  val numFetches = remoteRequests.size - fetchRequests.size
  logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))

  // Fetch the local blocks
  fetchLocalBlocks()
  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
The method does three things:
it splits the blocks into local and remote blocks and puts the resulting remote FetchRequests into the fetchRequests queue;
it dequeues requests from fetchRequests and sends them with sendRequest to fetch the remote data;
it fetches the local blocks.
Let's first look at how local and remote blocks are distinguished:
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
  // Divide the max in-flight size by 5 so that fetches go to up to 5 nodes in parallel
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
  logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

  // The remote fetch requests to return
  val remoteRequests = new ArrayBuffer[FetchRequest]
  // Tracks total number of blocks (including zero sized blocks)
  var totalBlocks = 0
  for ((address, blockInfos) <- blocksByAddress) {
    totalBlocks += blockInfos.size
    // If the block lives on this executor it is local, otherwise remote
    if (address.executorId == blockManager.blockManagerId.executorId) {
      // Filter out zero-sized blocks
      localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
      numBlocksToFetch += localBlocks.size
    } else {
      val iterator = blockInfos.iterator
      var curRequestSize = 0L
      var curBlocks = new ArrayBuffer[(BlockId, Long)]
      while (iterator.hasNext) {
        val (blockId, size) = iterator.next()
        // Skip empty blocks
        if (size > 0) {
          curBlocks += ((blockId, size))
          remoteBlocks += blockId
          numBlocksToFetch += 1
          curRequestSize += size
        } else if (size < 0) {
          throw new BlockException(blockId, "Negative block size " + size)
        }
        // Once the accumulated size reaches the limit, emit a FetchRequest
        if (curRequestSize >= targetRequestSize) {
          // Add this FetchRequest
          remoteRequests += new FetchRequest(address, curBlocks)
          curBlocks = new ArrayBuffer[(BlockId, Long)]
          logDebug(s"Creating fetch request of $curRequestSize at $address")
          curRequestSize = 0
        }
      }
      // Wrap any remaining blocks into one final FetchRequest
      if (curBlocks.nonEmpty) {
        remoteRequests += new FetchRequest(address, curBlocks)
      }
    }
  }
  logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
  remoteRequests
}
To increase the parallelism of remote fetches, the in-flight size limit is divided by 5 to get the per-request size, so that data is pulled from up to 5 nodes in parallel instead of one large request to a single node.
A block counts as local when it lives on the same executor as the current task.
The blocks of each remote node (Executor) are walked in order; whenever the accumulated size for a node reaches the per-request limit, a FetchRequest is built and appended to remoteRequests, and the leftover blocks form one final request. remoteRequests is then returned. A FetchRequest is simply a wrapper around the target address plus the blockIds and their sizes. A small sketch of this chunking logic is shown below.
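As a self-contained illustration of the size-based chunking (the block names and sizes are made up; 48 MB is the default of spark.reducer.maxSizeInFlight):

import scala.collection.mutable.ArrayBuffer

object ChunkBlocksExample extends App {
  // (blockId, size in bytes) pairs pretending to live on one remote executor
  val blocks = Seq(
    ("shuffle_0_0_2", 4L * 1024 * 1024),
    ("shuffle_0_1_2", 6L * 1024 * 1024),
    ("shuffle_0_2_2", 3L * 1024 * 1024),
    ("shuffle_0_3_2", 9L * 1024 * 1024),
    ("shuffle_0_4_2", 2L * 1024 * 1024))

  val maxBytesInFlight = 48L * 1024 * 1024                    // spark.reducer.maxSizeInFlight default
  val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)  // ~9.6 MB per request

  val requests = ArrayBuffer[Seq[(String, Long)]]()
  var cur = ArrayBuffer[(String, Long)]()
  var curSize = 0L
  for ((id, size) <- blocks if size > 0) {
    cur += ((id, size)); curSize += size
    if (curSize >= targetRequestSize) {      // emit a request once the limit is hit
      requests += cur.toSeq
      cur = ArrayBuffer[(String, Long)](); curSize = 0L
    }
  }
  if (cur.nonEmpty) requests += cur.toSeq    // final partial request

  requests.zipWithIndex.foreach { case (r, i) =>
    println(s"request $i: ${r.map(_._1).mkString(", ")} (${r.map(_._2).sum / (1024 * 1024)} MB)")
  }
}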
After local and remote blocks have been split, the remote requests sit in the fetchRequests queue and fetchUpToMaxBytes() is called to fetch the remote data:
private def fetchUpToMaxBytes(): Unit = {
  // Send fetch requests up to maxBytesInFlight
  while (fetchRequests.nonEmpty &&
    (bytesInFlight == 0 ||
      (reqsInFlight + 1 <= maxReqsInFlight &&
        bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
    sendRequest(fetchRequests.dequeue())
  }
}
It dequeues FetchRequests from fetchRequests and hands them to sendRequest:
private[this] def sendRequest(req: FetchRequest) {
  logDebug("Sending request for %d blocks (%s) from %s".format(
    req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
  bytesInFlight += req.size
  reqsInFlight += 1

  // Build a Map[blockId -> size] for the requested blocks
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
  val blockIds = req.blocks.map(_._1.toString)
  val address = req.address

  // Fetch the blocks on the remote node through shuffleClient.fetchBlocks
  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    new BlockFetchingListener {
      // Put successful results into the results queue
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        // Only add the buffer to results queue if the iterator is not zombie,
        // i.e. cleanup() has not been called yet.
        ShuffleBlockFetcherIterator.this.synchronized {
          if (!isZombie) {
            // Increment the ref count because we need to pass this to a different thread.
            // This needs to be released after use.
            buf.retain()
            remainingBlocks -= blockId
            results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
              remainingBlocks.isEmpty))
            logDebug("remainingBlocks: " + remainingBlocks)
          }
        }
        logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
      }

      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
        results.put(new FailureFetchResult(BlockId(blockId), address, e))
      }
    }
  )
}
shuffleClient.fetchBlocks fetches the blocks from the corresponding remote node; by default it is implemented by NettyBlockTransferService.fetchBlocks. Whether a block succeeds or fails, a SuccessFetchResult or FailureFetchResult is put into results. The callback-to-queue pattern used here is sketched below.
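The pattern (an asynchronous client pushes results from its callback thread into a blocking queue that the task thread drains) can be shown in isolation. Everything below is a hypothetical stand-in; in particular fetchBlocksAsync is a fake client, not the NettyBlockTransferService API:

import java.util.concurrent.LinkedBlockingQueue

object FetchCallbackExample extends App {
  sealed trait FetchResult
  case class SuccessFetch(blockId: String, bytes: Array[Byte]) extends FetchResult
  case class FailureFetch(blockId: String, error: Throwable) extends FetchResult

  val results = new LinkedBlockingQueue[FetchResult]()

  // Stand-in for an async transfer service: invokes the callbacks from another thread
  def fetchBlocksAsync(blockIds: Seq[String],
                       onSuccess: (String, Array[Byte]) => Unit,
                       onFailure: (String, Throwable) => Unit): Unit = {
    new Thread(new Runnable {
      override def run(): Unit = blockIds.foreach { id =>
        if (id.nonEmpty) onSuccess(id, Array.fill(4)(1.toByte))
        else onFailure(id, new IllegalArgumentException("empty block id"))
      }
    }).start()
  }

  val wanted = Seq("shuffle_0_0_1", "shuffle_0_1_1", "shuffle_0_2_1")
  fetchBlocksAsync(
    wanted,
    (id, bytes) => results.put(SuccessFetch(id, bytes)),  // callback thread: producer
    (id, err) => results.put(FailureFetch(id, err)))

  // Task thread: the consumer blocks on the queue, just as next() blocks on `results`
  for (_ <- wanted.indices) {
    results.take() match {
      case SuccessFetch(id, bytes) => println(s"got $id (${bytes.length} bytes)")
      case FailureFetch(id, err) => println(s"failed $id: ${err.getMessage}")
    }
  }
}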
With the remote fetches under way, fetchLocalBlocks() then reads the local blocks:
private[this] def fetchLocalBlocks() {
  val iter = localBlocks.iterator
  while (iter.hasNext) {
    val blockId = iter.next()
    try {
      val buf = blockManager.getBlockData(blockId)
      shuffleMetrics.incLocalBlocksFetched(1)
      shuffleMetrics.incLocalBytesRead(buf.size)
      buf.retain()
      results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
    } catch {
      case e: Exception =>
        // If we see an exception, stop immediately.
        logError(s"Error occurred while fetching local blocks", e)
        results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
        return
    }
  }
}
It iterates over the local blocks, reads each one directly from the blockManager, and puts a SuccessFetchResult or FailureFetchResult built from the result into results. Let's look at blockManager.getBlockData(blockId):
override def getBlockData(blockId: BlockId): ManagedBuffer = {
  if (blockId.isShuffle) {
    shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
  } else {
    getLocalBytes(blockId) match {
      case Some(buffer) => new BlockManagerManagedBuffer(blockInfoManager, blockId, buffer)
      case None =>
        // If this block manager receives a request for a block that it doesn't have then it's
        // likely that the master has outdated block statuses for this block. Therefore, we send
        // an RPC so that this block is marked as being unavailable from this block manager.
        reportBlockStatus(blockId, BlockStatus.empty)
        throw new BlockNotFoundException(blockId.toString)
    }
  }
}
For shuffle blocks this delegates to the shuffle manager's shuffleBlockResolver (for SortShuffleManager, the IndexShuffleBlockResolver). Let's look at its getBlockData method:
override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
  // Locate the index file by shuffleId and mapId
  val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
  val in = new DataInputStream(new FileInputStream(indexFile))
  try {
    // Skip to this block's entry in the index file
    ByteStreams.skipFully(in, blockId.reduceId * 8)
    // Start offset of this partition in the data file
    val offset = in.readLong()
    // End offset of this partition in the data file
    val nextOffset = in.readLong()
    new FileSegmentManagedBuffer(
      transportConf,
      getDataFile(blockId.shuffleId, blockId.mapId),
      offset,
      nextOffset - offset)
  } finally {
    in.close()
  }
}
It locates the index file from the shuffleId and mapId, opens a stream over it, and skips reduceId * 8 bytes to reach this block's entry (this reduceId is the one put into the ShuffleBlockId when the partition metadata was built earlier). It then reads the start and end offsets one after the other and serves exactly that byte range of the data file. The index layout is sketched below.
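The index file written by the map side is just numPartitions + 1 longs: the cumulative offsets of each reduce partition in the data file. Here is a self-contained sketch of writing such an index and looking up one partition's byte range (temporary file, made-up partition lengths):

import java.io.{DataInputStream, DataOutputStream, FileInputStream, FileOutputStream}
import java.nio.file.Files

object IndexFileExample extends App {
  // Lengths of reduce partitions 0..3 as written by one hypothetical map task
  val partitionLengths = Array(100L, 0L, 250L, 75L)

  val indexFile = Files.createTempFile("shuffle_0_0_0", ".index").toFile

  // Write cumulative offsets: numPartitions + 1 longs
  val out = new DataOutputStream(new FileOutputStream(indexFile))
  try {
    var offset = 0L
    out.writeLong(offset)
    for (len <- partitionLengths) {
      offset += len
      out.writeLong(offset)
    }
  } finally out.close()

  // Look up the byte range of reduce partition 2, like getBlockData does
  val reduceId = 2
  val in = new DataInputStream(new FileInputStream(indexFile))
  try {
    in.skipBytes(reduceId * 8)          // each offset is an 8-byte long
    val offset = in.readLong()          // start of partition 2 in the data file
    val nextOffset = in.readLong()      // start of partition 3, i.e. end of partition 2
    println(s"partition $reduceId occupies bytes [$offset, $nextOffset) " +
      s"(${nextOffset - offset} bytes)") // [100, 350), 250 bytes
  } finally in.close()
}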
Once the FetchResults are flowing into results, we are back in the read() method shown at the beginning. ShuffleBlockFetcherIterator itself extends Iterator, so results can be iterated; its next() method hands each FetchResult back as a (blockId, inputStream) pair:
case SuccessFetchResult(blockId, address, _, buf, _) =>
  try {
    (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
  } catch {
    case NonFatal(t) =>
      throwFetchFailedException(blockId, address, t)
  }
The second half of read() performs aggregation and sorting. It closely mirrors the Shuffle Write path, so only a rough walkthrough follows.
When aggregation is required, combineCombinersByKey is used if the map side already combined, otherwise combineValuesByKey; both end up calling ExternalAppendOnlyMap.insertAll(iter):
def combineCombinersByKey(
    iter: Iterator[_ <: Product2[K, C]],
    context: TaskContext): Iterator[(K, C)] = {
  val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
  combiners.insertAll(iter)
  updateMetrics(context, combiners)
  combiners.iterator
}

def combineValuesByKey(
    iter: Iterator[_ <: Product2[K, V]],
    context: TaskContext): Iterator[(K, C)] = {
  val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
  combiners.insertAll(iter)
  updateMetrics(context, combiners)
  combiners.iterator
}

def insertAll(entries: Iterator[Product2[K, V]]): Unit = {
  if (currentMap == null) {
    throw new IllegalStateException(
      "Cannot insert new elements into a map after calling iterator")
  }
  // An update function for the map that we reuse across entries to avoid allocating
  // a new closure each time
  var curEntry: Product2[K, V] = null
  val update: (Boolean, C) => C = (hadVal, oldVal) => {
    if (hadVal) mergeValue(oldVal, curEntry._2) else createCombiner(curEntry._2)
  }
  while (entries.hasNext) {
    curEntry = entries.next()
    val estimatedSize = currentMap.estimateSize()
    if (estimatedSize > _peakMemoryUsedBytes) {
      _peakMemoryUsedBytes = estimatedSize
    }
    if (maybeSpill(currentMap, estimatedSize)) {
      currentMap = new SizeTrackingAppendOnlyMap[K, C]
    }
    currentMap.changeValue(curEntry._1, update)
    addElementsRead()
  }
}
The iteration inside insertAll ultimately pulls its data through the ShuffleBlockFetcherIterator.next method described above.
Each update/insert also estimates the size of currentMap and checks whether it needs to be spilled to disk. If so, the map's data is sorted by key with the configured keyComparator, the resulting iterator is written to a temporary file, and a fresh map is created for new data. A simplified sketch of this combine-and-maybe-spill loop follows.
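Here is a minimal, self-contained sketch of the idea (the class is hypothetical; a record count stands in for Spark's byte-size estimate, and "spilling" just sorts and stashes the map in memory instead of writing a file):

import scala.collection.mutable

class TinyCombineMap[K: Ordering, V, C](createCombiner: V => C,
                                        mergeValue: (C, V) => C,
                                        spillThreshold: Int) {
  private var currentMap = mutable.HashMap.empty[K, C]
  // Each "spill" is kept as a key-sorted sequence, like a sorted spill file
  val spills = mutable.ArrayBuffer.empty[Seq[(K, C)]]

  def insertAll(entries: Iterator[(K, V)]): Unit = {
    for ((k, v) <- entries) {
      // Combine into the in-memory map (update if present, create otherwise)
      currentMap.get(k) match {
        case Some(c) => currentMap(k) = mergeValue(c, v)
        case None => currentMap(k) = createCombiner(v)
      }
      // Crude stand-in for estimateSize()/maybeSpill(): spill once the map is "too big"
      if (currentMap.size >= spillThreshold) {
        spills += currentMap.toSeq.sortBy(_._1)
        currentMap = mutable.HashMap.empty[K, C]
      }
    }
  }

  def inMemory: Seq[(K, C)] = currentMap.toSeq.sortBy(_._1)
}

object TinyCombineMapExample extends App {
  val m = new TinyCombineMap[String, Int, Int](v => v, _ + _, spillThreshold = 2)
  m.insertAll(Iterator("a" -> 1, "b" -> 2, "a" -> 3, "c" -> 4, "b" -> 5))
  println(m.spills)   // two sorted "spills" of two entries each
  println(m.inMemory) // remaining in-memory data, e.g. List((b,5))
}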
After insertAll on combiners (an ExternalAppendOnlyMap) completes, its iterator is called to return an iterator over the complete partition data, covering both the in-memory map and the spill files:
override def iterator: Iterator[(K, C)] = {
  if (currentMap == null) {
    throw new IllegalStateException(
      "ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
  }
  if (spilledMaps.isEmpty) {
    CompletionIterator[(K, C), Iterator[(K, C)]](
      destructiveIterator(currentMap.iterator), freeCurrentMap())
  } else {
    new ExternalIterator()
  }
}
Let's follow what happens when ExternalIterator is instantiated:
// A queue that maintains a buffer for each stream we are currently merging
// This queue maintains the invariant that it only contains non-empty buffers
private val mergeHeap = new mutable.PriorityQueue[StreamBuffer]

// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator(
  currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap())
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)

inputStreams.foreach { it =>
  val kcPairs = new ArrayBuffer[(K, C)]
  readNextHashCode(it, kcPairs)
  if (kcPairs.length > 0) {
    mergeHeap.enqueue(new StreamBuffer(it, kcPairs))
  }
}
The sorted contents of currentMap and the iterators over the spill files are combined into inputStreams. For each stream, the first batch of key/value pairs is read into a StreamBuffer and placed on mergeHeap; ExternalIterator's next() then drains this heap to merge the streams. A stripped-down sketch of the general k-way merge idea follows.
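The sketch below merges several key-sorted streams with a priority queue, summing values on equal keys. It only illustrates the general k-way merge; Spark's real ExternalIterator additionally batches entries by key hash code:

import scala.collection.mutable

object KWayMergeExample extends App {
  // Three already-sorted inputs: one "in-memory map" and two "spill files"
  val streams: Seq[BufferedIterator[(Int, Int)]] = Seq(
    Iterator(1 -> 10, 4 -> 40).buffered,
    Iterator(1 -> 1, 2 -> 2, 4 -> 4).buffered,
    Iterator(3 -> 30).buffered)

  // Min-heap on the head key of each stream (PriorityQueue is a max-heap, hence .reverse)
  implicit val byHeadKey: Ordering[BufferedIterator[(Int, Int)]] =
    Ordering.by[BufferedIterator[(Int, Int)], Int](it => it.head._1).reverse
  val mergeHeap = mutable.PriorityQueue(streams.filter(_.hasNext): _*)

  val merged = mutable.ArrayBuffer.empty[(Int, Int)]
  while (mergeHeap.nonEmpty) {
    val smallest = mergeHeap.dequeue()
    val (key, value) = smallest.next()
    // Combine with the previous entry if the same key appears in several streams
    if (merged.nonEmpty && merged.last._1 == key)
      merged(merged.size - 1) = (key, merged.last._2 + value)
    else
      merged += (key -> value)
    if (smallest.hasNext) mergeHeap.enqueue(smallest) // re-insert with its new head
  }

  println(merged) // ArrayBuffer((1,11), (2,2), (3,30), (4,44))
}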
Finally, if the data needs a total ordering, it is sorted with an ExternalSorter constructed with only the ordering argument, exactly as in the Shuffle Write path, so the details are not repeated here.
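Whether this last sorting pass runs depends on the operation that created the shuffle. As a rough, user-level illustration (reusing the hypothetical sc from the demo near the top, not the reader's internal API):

// reduceByKey sets an aggregator but no keyOrdering, so read() stops after aggregation
val counts = sc.parallelize(Seq("b" -> 2, "a" -> 1, "b" -> 3), 2).reduceByKey(_ + _)

// sortByKey builds a ShuffledRDD whose dependency carries a key ordering, so the reader
// runs aggregatedIter through an ExternalSorter before returning it
val sorted = sc.parallelize(Seq("b" -> 2, "a" -> 1, "b" -> 3), 2).sortByKey()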
The final result is an iterator over all of the data belonging to the specified partition.
Author: BIGUFO
Link: https://www.jianshu.com/p/50278b0a0050